mirror of
https://github.com/ollama/ollama.git
synced 2026-02-18 23:36:41 -05:00
Compare commits
14 Commits
v0.16.2
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
365a3657ad | ||
|
|
71c1d8d0a9 | ||
|
|
325b72bc31 | ||
|
|
f01a9a7859 | ||
|
|
9aefd2dfee | ||
|
|
d07e4a1dd3 | ||
|
|
8a257ec00a | ||
|
|
2f4de1acf7 | ||
|
|
ec95c45f70 | ||
|
|
3a88f7eb20 | ||
|
|
0d5da826d4 | ||
|
|
9b795698b8 | ||
|
|
041fb77639 | ||
|
|
8224cce583 |
@@ -1 +1 @@
|
||||
v0.4.1
|
||||
v0.5.0
|
||||
|
||||
@@ -16,7 +16,7 @@ Start building with open models.
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
```
|
||||
|
||||
or [download manually](http://localhost:8080/download/Ollama.dmg)
|
||||
or [download manually](https://ollama.com/download/Ollama.dmg)
|
||||
|
||||
### Windows
|
||||
|
||||
|
||||
22
auth/auth.go
22
auth/auth.go
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
// signature is <pubkey>:<signature>
|
||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||
}
|
||||
|
||||
// SignRequest adds a nonce query parameter and an Authorization header with
|
||||
// an Ed25519 signature to req.
|
||||
func SignRequest(ctx context.Context, req *http.Request) error {
|
||||
nonce, err := NewNonce(rand.Reader, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Set("nonce", nonce)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
|
||||
signature, err := Sign(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
40
cmd/cmd.go
40
cmd/cmd.go
@@ -57,9 +57,9 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1896,10 +1900,25 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
|
||||
if version.HasCachedUpdate() {
|
||||
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
|
||||
_ = version.ClearCachedUpdate()
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if available, err := version.CheckForUpdate(ctx); err == nil && available {
|
||||
_ = version.CacheAvailableUpdate()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
@@ -2313,6 +2332,18 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
updateCmd := &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update Ollama to the latest version",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
_ = version.ClearCachedUpdate()
|
||||
return version.DoUpdate(force)
|
||||
},
|
||||
}
|
||||
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
|
||||
|
||||
rootCmd.AddCommand(
|
||||
serveCmd,
|
||||
createCmd,
|
||||
@@ -2330,6 +2361,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
updateCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items)
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
123
cmd/config/cline.go
Normal file
123
cmd/config/cline.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Cline implements Runner and Editor for the Cline CLI integration
|
||||
type Cline struct{}
|
||||
|
||||
func (c *Cline) String() string { return "Cline" }
|
||||
|
||||
func (c *Cline) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("cline"); err != nil {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (c *Cline) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cline) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ollama as the provider for both act and plan modes
|
||||
baseURL := envconfig.Host().String()
|
||||
config["ollamaBaseUrl"] = baseURL
|
||||
config["actModeApiProvider"] = "ollama"
|
||||
config["actModeOllamaModelId"] = models[0]
|
||||
config["actModeOllamaBaseUrl"] = baseURL
|
||||
config["planModeApiProvider"] = "ollama"
|
||||
config["planModeOllamaModelId"] = models[0]
|
||||
config["planModeOllamaBaseUrl"] = baseURL
|
||||
|
||||
config["welcomeViewCompleted"] = true
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{modelID}
|
||||
}
|
||||
204
cmd/config/cline_test.go
Normal file
204
cmd/config/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClineIntegration(t *testing.T) {
|
||||
c := &Cline{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Cline" {
|
||||
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineEdit(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var config map[string]any
|
||||
json.Unmarshal(data, &config)
|
||||
return config
|
||||
}
|
||||
|
||||
t.Run("creates config from scratch", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeApiProvider"] != "ollama" {
|
||||
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
if config["welcomeViewCompleted"] != true {
|
||||
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing fields", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existing := map[string]any{
|
||||
"remoteRulesToggles": map[string]any{},
|
||||
"remoteWorkflowToggles": map[string]any{},
|
||||
"customSetting": "keep-me",
|
||||
}
|
||||
data, _ := json.Marshal(existing)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["customSetting"] != "keep-me" {
|
||||
t.Errorf("customSetting was not preserved")
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Error("expected no config file to be created for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses first model as primary", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineModels(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
t.Run("returns nil when no config", func(t *testing.T) {
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "anthropic",
|
||||
"actModeOllamaModelId": "some-model",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "ollama",
|
||||
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClinePaths(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Openclaw{},
|
||||
"cline": &Cline{},
|
||||
"codex": &Codex{},
|
||||
"moltbot": &Openclaw{},
|
||||
"droid": &Droid{},
|
||||
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
"moltbot": true,
|
||||
"pi": true,
|
||||
}
|
||||
|
||||
// integrationInstallHints maps integration names to install URLs.
|
||||
var integrationInstallHints = map[string]string{
|
||||
"claude": "https://code.claude.com/docs/en/quickstart",
|
||||
"cline": "https://cline.bot/cli",
|
||||
"openclaw": "https://docs.openclaw.ai",
|
||||
"codex": "https://developers.openai.com/codex/cli/",
|
||||
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
"opencode": "https://opencode.ai",
|
||||
"pi": "https://github.com/badlogic/pi-mono",
|
||||
}
|
||||
|
||||
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
||||
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
|
||||
// integrationDescriptions maps integration names to short descriptions.
|
||||
var integrationDescriptions = map[string]string{
|
||||
"claude": "Anthropic's coding tool with subagents",
|
||||
"cline": "Autonomous coding agent with parallel execution",
|
||||
"codex": "OpenAI's open-source coding agent",
|
||||
"openclaw": "Personal AI with 100+ skills",
|
||||
"droid": "Factory's coding agent across terminal and IDEs",
|
||||
"opencode": "Anomaly's open-source coding agent",
|
||||
"pi": "Minimal AI agent toolkit with plugin support",
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
|
||||
// integrationOrder defines a custom display order for integrations.
|
||||
// Integrations listed here are placed at the end in the given order;
|
||||
// all others appear first, sorted alphabetically.
|
||||
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
|
||||
// with integrationOrder entries placed at the end.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
var result []IntegrationInfo
|
||||
for name, r := range integrations {
|
||||
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
|
||||
Description: integrationDescriptions[name],
|
||||
})
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
|
||||
}
|
||||
|
||||
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
// Both have custom order: sort by their rank
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
// Only one has custom order: it goes last
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
// Neither has custom order: alphabetical
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return result
|
||||
@@ -186,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
|
||||
case "droid":
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
case "cline":
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
case "opencode":
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
case "pi":
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
default:
|
||||
return true // Assume installed for unknown integrations
|
||||
}
|
||||
@@ -214,7 +248,8 @@ type ModelItem struct {
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
@@ -257,7 +292,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
selected, err := selector("Select model to run:", items)
|
||||
selected, err := selector("Select model to run:", items, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -367,13 +402,11 @@ func selectIntegration() (string, error) {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []ModelItem
|
||||
for _, name := range names {
|
||||
for name, r := range integrations {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
@@ -381,7 +414,25 @@ func selectIntegration() (string, error) {
|
||||
items = append(items, ModelItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items)
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
slices.SortFunc(items, func(a, b ModelItem) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items, "")
|
||||
}
|
||||
|
||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||
@@ -439,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := single(prompt, items)
|
||||
model, err := single(prompt, items, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -812,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
@@ -915,11 +968,9 @@ Examples:
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
cloudCleared := false
|
||||
if model != "" && modelFlag == "" {
|
||||
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
|
||||
model = ""
|
||||
cloudCleared = true
|
||||
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
||||
@@ -928,18 +979,16 @@ Examples:
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
// Show picker so user can change model (skip when --model flag provided)
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
@@ -1001,27 +1050,13 @@ Examples:
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
savedModels := filterDisabledCloudModels(saved.Models)
|
||||
if len(savedModels) != len(saved.Models) {
|
||||
_ = SaveIntegration(name, savedModels)
|
||||
}
|
||||
if len(savedModels) == 0 {
|
||||
// All saved models were cloud — fall through to picker
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
models = savedModels
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
} else {
|
||||
current := ""
|
||||
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
|
||||
current = saved.Models[0]
|
||||
}
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
models, err = selectModels(cmd.Context(), name, current)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sorted by name", func(t *testing.T) {
|
||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
||||
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
||||
// All other entries should be sorted alphabetically before them.
|
||||
orderRank := make(map[string]int)
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
for i := 1; i < len(infos); i++ {
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
||||
switch {
|
||||
case aRank == 0 && bRank == 0:
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
}
|
||||
case aRank > 0 && bRank == 0:
|
||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
||||
case aRank > 0 && bRank > 0:
|
||||
if aRank >= bRank {
|
||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||
func cursorForCurrent(items []SelectItem, current string) int {
|
||||
if current != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(m.checkOrder) > 0 {
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
m.toggleItem()
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
@@ -579,7 +619,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
m.toggleItem()
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
@@ -591,6 +633,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
@@ -602,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
@@ -624,6 +679,11 @@ func (m multiSelectorModel) View() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
@@ -648,7 +708,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -671,7 +731,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,7 +751,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -703,15 +763,18 @@ func (m multiSelectorModel) View() string {
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
@@ -734,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- cursorForCurrent ---
|
||||
|
||||
func TestCursorForCurrent(t *testing.T) {
|
||||
testItems := []SelectItem{
|
||||
{Name: "llama3.2", Recommended: true},
|
||||
{Name: "qwen3:8b", Recommended: true},
|
||||
{Name: "gemma3:latest"},
|
||||
{Name: "deepseek-r1"},
|
||||
{Name: "glm-5:cloud"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current string
|
||||
want int
|
||||
}{
|
||||
{"empty current", "", 0},
|
||||
{"exact match", "qwen3:8b", 1},
|
||||
{"no match returns 0", "nonexistent", 0},
|
||||
{"bare name matches with :latest suffix", "gemma3", 2},
|
||||
{"full tag matches bare item", "llama3.2:latest", 0},
|
||||
{"cloud model exact match", "glm-5:cloud", 4},
|
||||
{"cloud model bare name", "glm-5", 4},
|
||||
{"recommended item exact match", "llama3.2", 0},
|
||||
{"recommended item with tag", "qwen3", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
|
||||
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReorderItems ---
|
||||
|
||||
func TestReorderItems(t *testing.T) {
|
||||
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "[x]") {
|
||||
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiView_DefaultTag(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Error("first checked item should have (default) tag")
|
||||
t.Error("should have (default) tag")
|
||||
}
|
||||
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||
aIdx := strings.Index(content, "a")
|
||||
defaultIdx := strings.Index(content, "(default)")
|
||||
if defaultIdx < aIdx {
|
||||
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,6 +593,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeySpace
|
||||
@@ -565,6 +610,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||
@@ -582,6 +628,161 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single-add mode ---
|
||||
|
||||
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
if m.multi {
|
||||
t.Error("should start in single mode (multi=false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
content := m.View()
|
||||
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||
t.Error("single mode should not show checkboxes")
|
||||
}
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("single mode should show cursor indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.cursor = 1
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if m.singleAdd != "b" {
|
||||
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||
}
|
||||
if !m.confirmed {
|
||||
t.Error("should set confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space in single mode should not toggle items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space rune in single mode should not toggle items")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune in single mode should not add to filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
|
||||
if m.multi {
|
||||
t.Fatal("should start in single mode")
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if !m.multi {
|
||||
t.Error("tab should switch to multi mode")
|
||||
}
|
||||
|
||||
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if m.multi {
|
||||
t.Error("tab should switch back to single mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab add multiple") {
|
||||
t.Error("single mode should show 'tab add multiple' in help")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
|
||||
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||
// preChecked[0] ("a") is the current default and should end up
|
||||
// last in checkOrder so it gets the (default) tag.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||
|
||||
if len(m.checkOrder) != 3 {
|
||||
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||
}
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||
// preChecked[0] ("b") is the default; cursor should start on it
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-mode last-checked is default ---
|
||||
|
||||
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||
m.multi = true
|
||||
|
||||
// Check "alpha" then "gamma"
|
||||
m.cursor = 0
|
||||
m.toggleItem()
|
||||
m.cursor = 2
|
||||
m.toggleItem()
|
||||
|
||||
// Last checked ("gamma") should be at the end of checkOrder
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "gamma" {
|
||||
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
|
||||
// The (default) tag renders based on checkOrder[len-1]
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatal("should show (default) tag")
|
||||
}
|
||||
// "alpha" line should NOT have the default tag
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -429,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
for _, idx := range m.multiModalSelector.checkOrder {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
|
||||
@@ -106,20 +106,23 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/index",
|
||||
{
|
||||
"group": "Assistants",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Coding",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Assistants",
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
"/integrations/goose",
|
||||
"/integrations/pi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
- [Pi](/integrations/pi)
|
||||
|
||||
## Assistants
|
||||
|
||||
|
||||
57
docs/integrations/pi.mdx
Normal file
57
docs/integrations/pi.mdx
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Pi
|
||||
---
|
||||
|
||||
Pi is a minimal AI agent toolkit with plugin support.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Pi](https://github.com/badlogic/pi-mono):
|
||||
|
||||
```bash
|
||||
npm install -g @mariozechner/pi-coding-agent
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch pi
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch pi --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.pi/agent/models.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{
|
||||
"id": "qwen3-coder"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Update `~/.pi/agent/settings.json` to set the default provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"defaultProvider": "ollama",
|
||||
"defaultModel": "qwen3-coder"
|
||||
}
|
||||
```
|
||||
@@ -27,9 +27,17 @@ The menu provides quick access to:
|
||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||
- **Additional integrations** - Available under "More..."
|
||||
|
||||
## Assistants
|
||||
|
||||
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
|
||||
|
||||
```sh
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
## Coding
|
||||
|
||||
Launch coding tools with Ollama models:
|
||||
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch claude
|
||||
|
||||
@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
|
||||
var p Parser
|
||||
|
||||
switch name {
|
||||
case "qwen3":
|
||||
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
||||
name string
|
||||
}{
|
||||
{"passthrough"},
|
||||
{"qwen3"},
|
||||
{"qwen3-thinking"},
|
||||
{"qwen3-coder"},
|
||||
{"harmony"},
|
||||
}
|
||||
|
||||
335
model/parsers/qwen3.go
Normal file
335
model/parsers/qwen3.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type qwen3ParserState int
|
||||
|
||||
const (
|
||||
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
|
||||
qwen3ParserStateThinkingStartedEatingWhitespace
|
||||
qwen3ParserStateCollectingThinking
|
||||
qwen3ParserStateThinkingDoneEatingWhitespace
|
||||
qwen3ParserStateCollectingContent
|
||||
qwen3ParserStateToolStartedEatingWhitespace
|
||||
qwen3ParserStateCollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
qwen3ThinkingOpenTag = "<think>"
|
||||
qwen3ThinkingCloseTag = "</think>"
|
||||
qwen3ToolOpenTag = "<tool_call>"
|
||||
qwen3ToolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
|
||||
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
|
||||
// with thinking content directly (without an opening tag).
|
||||
type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
thinkingEnabled = p.defaultThinking
|
||||
}
|
||||
|
||||
if p.hasThinkingSupport && thinkingEnabled {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
p.maybeThinkingOpenAtBOL = true
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
type qwen3Event interface {
|
||||
isQwen3Event()
|
||||
}
|
||||
|
||||
type qwen3EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen3EventContent) isQwen3Event() {}
|
||||
|
||||
type qwen3EventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
func (qwen3EventRawToolCall) isQwen3Event() {}
|
||||
|
||||
type qwen3EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen3EventThinkingContent) isQwen3Event() {}
|
||||
|
||||
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwen3EventRawToolCall:
|
||||
toolCall, err := parseQwen3ToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case qwen3EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) parseEvents() []qwen3Event {
|
||||
var all []qwen3Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwen3Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
var events []qwen3Event
|
||||
|
||||
switch p.state {
|
||||
case qwen3ParserStateLookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
return events, false
|
||||
}
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
|
||||
case qwen3ParserStateThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
|
||||
|
||||
case qwen3ParserStateCollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
|
||||
// Some qwen3 checkpoints emit an explicit opening <think> tag even
|
||||
// though the prompt already ended with <think>. Strip exactly one
|
||||
// leading opening tag if present.
|
||||
if p.maybeThinkingOpenAtBOL {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
return events, true
|
||||
}
|
||||
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
|
||||
|
||||
case qwen3ParserStateCollectingContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolOpenTag) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
|
||||
|
||||
case qwen3ParserStateCollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolCloseTag) {
|
||||
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("qwen3 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, qwen3EventRawToolCall{raw: toolContent})
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
}
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
if parsed.Name == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
|
||||
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for key, value := range parsed.Arguments {
|
||||
toolCall.Function.Arguments.Set(key, value)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
147
model/parsers/qwen3_test.go
Normal file
147
model/parsers/qwen3_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen3ParserThinkingEnabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<thi", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingDisabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCall(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
location, ok := calls[0].Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||
}
|
||||
unit, ok := calls[0].Function.Arguments.Get("unit")
|
||||
if !ok || unit != "celsius" {
|
||||
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||
}
|
||||
}
|
||||
@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loadedModel, err := GetModel("test-image")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
schedulerModelKey(loadedModel): {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: loadedModel,
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
|
||||
@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
return sched
|
||||
}
|
||||
|
||||
// schedulerModelKey returns the scheduler map key for a model.
|
||||
// GGUF-backed models use ModelPath; safetensors/image models without a
|
||||
// ModelPath use manifest digest so distinct models don't collide.
|
||||
func schedulerModelKey(m *Model) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if m.ModelPath != "" {
|
||||
return m.ModelPath
|
||||
}
|
||||
if m.Digest != "" {
|
||||
return "digest:" + m.Digest
|
||||
}
|
||||
if m.Name != "" {
|
||||
return "name:" + m.Name
|
||||
}
|
||||
if m.ShortName != "" {
|
||||
return "short:" + m.ShortName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
||||
useImagegen: useImagegen,
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[req.model.ModelPath]
|
||||
runner := s.loaded[key]
|
||||
s.loadedMu.Unlock()
|
||||
if runner != nil && !runner.needsReload(c, req) {
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
pendingKey := schedulerModelKey(pending.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pending.model.ModelPath]
|
||||
runner := s.loaded[pendingKey]
|
||||
loadedCount := len(s.loaded)
|
||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
runnerToExpire = runner
|
||||
} else {
|
||||
// Runner is usable, return it
|
||||
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
|
||||
logutil.Trace("using existing loaded runner", "model", pendingKey)
|
||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||
break
|
||||
}
|
||||
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
slog.Debug("shutting down scheduler completed loop")
|
||||
return
|
||||
case finished := <-s.finishedReqCh:
|
||||
finishedKey := schedulerModelKey(finished.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[finished.model.ModelPath]
|
||||
runner := s.loaded[finishedKey]
|
||||
s.loadedMu.Unlock()
|
||||
if runner == nil {
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
|
||||
continue
|
||||
}
|
||||
runner.refMu.Lock()
|
||||
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
|
||||
s.loadedMu.Lock()
|
||||
slog.Debug("got lock to unload expired event", "runner", runner)
|
||||
runnerToUnload := s.loaded[runner.modelPath]
|
||||
runnerToUnload := s.loaded[runner.modelKey]
|
||||
if runnerToUnload == nil {
|
||||
// If runnerToUnload is nil, we already processed an event and
|
||||
// unloaded it. This double unload can happen if the initial
|
||||
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
}
|
||||
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
||||
runner.unload()
|
||||
delete(s.loaded, runner.modelPath)
|
||||
delete(s.loaded, runner.modelKey)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
||||
<-finished
|
||||
@@ -514,6 +539,7 @@ iGPUScan:
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: llama,
|
||||
Options: &req.opts,
|
||||
sessionDuration: sessionDuration,
|
||||
@@ -528,7 +554,7 @@ iGPUScan:
|
||||
runner.refMu.Lock() // hold lock until running or aborted
|
||||
|
||||
s.loadedMu.Lock()
|
||||
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
|
||||
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
|
||||
// Shouldn't happen, but safeguard against leaking a runner
|
||||
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
||||
oldRunner.refMu.Lock()
|
||||
@@ -536,7 +562,7 @@ iGPUScan:
|
||||
oldRunner.refMu.Unlock()
|
||||
}
|
||||
s.activeLoading = nil
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
slog.Info("loaded runners", "count", len(s.loaded))
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Set up expiration timer
|
||||
@@ -684,6 +711,7 @@ type runnerRef struct {
|
||||
|
||||
model *Model
|
||||
modelPath string
|
||||
modelKey string
|
||||
numParallel int
|
||||
*api.Options
|
||||
}
|
||||
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
|
||||
}
|
||||
|
||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
||||
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
if runner == nil {
|
||||
return slog.StringValue("nil")
|
||||
}
|
||||
modelID := runner.modelPath
|
||||
if modelID == "" {
|
||||
modelID = runner.modelKey
|
||||
}
|
||||
attrs := []slog.Attr{}
|
||||
if runner.model != nil {
|
||||
attrs = append(attrs, slog.String("name", runner.model.Name))
|
||||
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||
slog.Int("parallel", runner.numParallel),
|
||||
slog.Int("pid", runner.pid),
|
||||
slog.String("model", runner.modelPath),
|
||||
slog.String("model", modelID),
|
||||
)
|
||||
if runner.Options != nil {
|
||||
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
||||
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
|
||||
if d1 != d2 {
|
||||
return d1 < d2
|
||||
}
|
||||
// Secondary sort by model path lex order
|
||||
return a[i].modelPath < a[j].modelPath
|
||||
// Secondary sort by model key/path lex order
|
||||
n1 := a[i].modelPath
|
||||
if n1 == "" {
|
||||
n1 = a[i].modelKey
|
||||
}
|
||||
n2 := a[j].modelPath
|
||||
if n2 == "" {
|
||||
n2 = a[j].modelKey
|
||||
}
|
||||
return n1 < n2
|
||||
}
|
||||
|
||||
// TODO - future consideration to pick runners based on size
|
||||
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
|
||||
}
|
||||
|
||||
func (s *Scheduler) expireRunner(model *Model) {
|
||||
modelKey := schedulerModelKey(model)
|
||||
s.loadedMu.Lock()
|
||||
runner, ok := s.loaded[model.ModelPath]
|
||||
runner, ok := s.loaded[modelKey]
|
||||
s.loadedMu.Unlock()
|
||||
if ok {
|
||||
runner.refMu.Lock()
|
||||
|
||||
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
b.ctxDone()
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||
|
||||
require.Empty(t, successCh)
|
||||
require.Empty(t, errCh)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||
cancelReq()
|
||||
|
||||
select {
|
||||
case runner := <-successCh:
|
||||
require.Equal(t, loadedRunner, runner)
|
||||
default:
|
||||
t.Fatal("expected existing runner to be reused")
|
||||
}
|
||||
|
||||
require.Empty(t, errCh)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
}
|
||||
|
||||
func TestSchedExpireRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
190
version/update.go
Normal file
190
version/update.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
var updateCheckURLBase = "https://ollama.com"
|
||||
|
||||
// CheckForUpdate calls the ollama.com update API and reports whether a
|
||||
// newer version is available.
|
||||
func CheckForUpdate(ctx context.Context) (bool, error) {
|
||||
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse update URL: %w", err)
|
||||
}
|
||||
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", Version)
|
||||
requestURL.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
_ = auth.SignRequest(ctx, req)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("update check request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func cacheFilePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "update"), nil
|
||||
}
|
||||
|
||||
// CacheAvailableUpdate creates the update marker file.
|
||||
func CacheAvailableUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// HasCachedUpdate reports whether a non-stale update marker exists.
|
||||
func HasCachedUpdate() bool {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(fi.ModTime()) <= 24*time.Hour
|
||||
}
|
||||
|
||||
// ClearCachedUpdate removes the update marker file.
|
||||
func ClearCachedUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Remove(path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func IsOfficialInstall() bool {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
|
||||
case "darwin":
|
||||
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
|
||||
default:
|
||||
dir := filepath.Dir(exe)
|
||||
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
|
||||
}
|
||||
}
|
||||
|
||||
// DoUpdate downloads and runs the platform-appropriate install script.
|
||||
func DoUpdate(force bool) error {
|
||||
if !force && !IsOfficialInstall() {
|
||||
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
|
||||
}
|
||||
|
||||
var scriptURL, tmpPattern, shell string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
scriptURL = "https://ollama.com/install.ps1"
|
||||
tmpPattern = "ollama-install-*.ps1"
|
||||
shell = "powershell"
|
||||
default:
|
||||
scriptURL = "https://ollama.com/install.sh"
|
||||
tmpPattern = "ollama-install-*.sh"
|
||||
shell = "sh"
|
||||
}
|
||||
|
||||
resp, err := http.Get(scriptURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download install script: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download install script: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", tmpPattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("write install script: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cmd := exec.Command(shell, tmpFile.Name())
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// IsLocalHost reports whether the configured Ollama host points to the
|
||||
// local machine.
|
||||
func IsLocalHost(host *url.URL) bool {
|
||||
hostname := host.Hostname()
|
||||
switch hostname {
|
||||
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
146
version/update_test.go
Normal file
146
version/update_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
} else {
|
||||
t.Setenv("HOME", dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
t.Run("update available", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
|
||||
t.Error("missing expected query parameters")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Fatal("expected update to be available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("up to date", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if available {
|
||||
t.Fatal("expected no update available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("network error", func(t *testing.T) {
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = "http://localhost:1"
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := CheckForUpdate(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheRoundTrip(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
if !HasCachedUpdate() {
|
||||
t.Fatal("expected cached update to be present")
|
||||
}
|
||||
|
||||
if err := ClearCachedUpdate(); err != nil {
|
||||
t.Fatalf("cache clear: %v", err)
|
||||
}
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCachedUpdateStale(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
// Backdate the file to make it stale
|
||||
path := filepath.Join(tmp, ".ollama", "update")
|
||||
staleTime := time.Now().Add(-25 * time.Hour)
|
||||
os.Chtimes(path, staleTime, staleTime)
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update for stale file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
host string
|
||||
local bool
|
||||
}{
|
||||
{"http://127.0.0.1:11434", true},
|
||||
{"http://localhost:11434", true},
|
||||
{"http://[::1]:11434", true},
|
||||
{"http://0.0.0.0:11434", true},
|
||||
{"http://remote.example.com:11434", false},
|
||||
{"http://192.168.1.100:11434", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.host, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.host)
|
||||
if err != nil {
|
||||
t.Fatalf("parse URL: %v", err)
|
||||
}
|
||||
if got := IsLocalHost(u); got != tt.local {
|
||||
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,8 @@ type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
Parser string
|
||||
Renderer string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
@@ -37,7 +39,7 @@ type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: parserName,
|
||||
Renderer: rendererName,
|
||||
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
}
|
||||
}
|
||||
|
||||
func resolveParserName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Parser != "" {
|
||||
return mf.Parser
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Renderer != "" {
|
||||
return mf.Renderer
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
Parser: "qwen3",
|
||||
Renderer: "qwen3",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
if config.Parser != "qwen3" {
|
||||
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
|
||||
}
|
||||
if config.Renderer != "qwen3" {
|
||||
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Errorf("Parser should be empty, got %q", config.Parser)
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Errorf("Renderer should be empty, got %q", config.Renderer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Error("Parser should be empty")
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Error("Renderer should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
Parser: "qwen3-thinking",
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
if opts.Modelfile.Parser != "qwen3-thinking" {
|
||||
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
|
||||
}
|
||||
if opts.Modelfile.Renderer != "qwen3" {
|
||||
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveParserName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "empty parser uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "explicit parser overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "qwen3-thinking",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3-thinking",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRendererName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "empty renderer uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "explicit renderer overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Private headers contain C++ implementation helpers and are not part of
|
||||
// the C API surface; parsing them can produce invalid wrapper signatures.
|
||||
if d.IsDir() && d.Name() == "private" {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
||||
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
||||
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
||||
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
||||
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
||||
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
||||
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_enable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
||||
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
||||
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
||||
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
|
||||
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
||||
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
|
||||
if (mlx_array_new_data_managed_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
|
||||
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
|
||||
if (mlx_array_set_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
||||
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
|
||||
if (mlx_cuda_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
|
||||
if (mlx_device_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
||||
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
|
||||
if (mlx_device_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
|
||||
if (mlx_device_count_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
|
||||
if (mlx_device_info_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
|
||||
if (mlx_device_info_get_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
|
||||
if (mlx_device_info_free_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
|
||||
if (mlx_device_info_has_key_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
|
||||
if (mlx_device_info_is_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
|
||||
if (mlx_device_info_get_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
|
||||
if (mlx_device_info_get_size_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
|
||||
if (mlx_device_info_get_keys_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
|
||||
if (mlx_distributed_all_gather_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
||||
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
|
||||
if (mlx_metal_device_info_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
|
||||
if (mlx_metal_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
||||
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
|
||||
return mlx_array_new_data_ptr(data, shape, dim, dtype);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
|
||||
}
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src) {
|
||||
return mlx_array_set_ptr(arr, src);
|
||||
}
|
||||
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
|
||||
return mlx_array_item_float64_ptr(res, arr);
|
||||
}
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
||||
return mlx_array_item_complex64_ptr(res, arr);
|
||||
}
|
||||
|
||||
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
|
||||
return mlx_array_data_float64_ptr(arr);
|
||||
}
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
|
||||
return mlx_array_data_complex64_ptr(arr);
|
||||
}
|
||||
|
||||
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
|
||||
return mlx_set_compile_mode_ptr(mode);
|
||||
}
|
||||
|
||||
int mlx_cuda_is_available(bool* res) {
|
||||
return mlx_cuda_is_available_ptr(res);
|
||||
}
|
||||
|
||||
mlx_device mlx_device_new(void) {
|
||||
return mlx_device_new_ptr();
|
||||
}
|
||||
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
|
||||
return mlx_set_default_device_ptr(dev);
|
||||
}
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
||||
return mlx_device_is_available_ptr(avail, dev);
|
||||
}
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type) {
|
||||
return mlx_device_count_ptr(count, type);
|
||||
}
|
||||
|
||||
mlx_device_info mlx_device_info_new(void) {
|
||||
return mlx_device_info_new_ptr();
|
||||
}
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
||||
return mlx_device_info_get_ptr(info, dev);
|
||||
}
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info) {
|
||||
return mlx_device_info_free_ptr(info);
|
||||
}
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_has_key_ptr(exists, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_is_string_ptr(is_string, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_string_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_size_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
||||
return mlx_device_info_get_keys_ptr(keys, info);
|
||||
}
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
||||
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
||||
}
|
||||
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
|
||||
return mlx_set_wired_limit_ptr(res, limit);
|
||||
}
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void) {
|
||||
return mlx_metal_device_info_ptr();
|
||||
}
|
||||
|
||||
int mlx_metal_is_available(bool* res) {
|
||||
return mlx_metal_is_available_ptr(res);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@
|
||||
#undef mlx_array_new_double
|
||||
#undef mlx_array_new_complex
|
||||
#undef mlx_array_new_data
|
||||
#undef mlx_array_new_data_managed
|
||||
#undef mlx_array_new_data_managed_payload
|
||||
#undef mlx_array_set
|
||||
#undef mlx_array_set_bool
|
||||
#undef mlx_array_set_int
|
||||
@@ -121,6 +123,7 @@
|
||||
#undef mlx_disable_compile
|
||||
#undef mlx_enable_compile
|
||||
#undef mlx_set_compile_mode
|
||||
#undef mlx_cuda_is_available
|
||||
#undef mlx_device_new
|
||||
#undef mlx_device_new_type
|
||||
#undef mlx_device_free
|
||||
@@ -131,6 +134,16 @@
|
||||
#undef mlx_device_get_type
|
||||
#undef mlx_get_default_device
|
||||
#undef mlx_set_default_device
|
||||
#undef mlx_device_is_available
|
||||
#undef mlx_device_count
|
||||
#undef mlx_device_info_new
|
||||
#undef mlx_device_info_get
|
||||
#undef mlx_device_info_free
|
||||
#undef mlx_device_info_has_key
|
||||
#undef mlx_device_info_is_string
|
||||
#undef mlx_device_info_get_string
|
||||
#undef mlx_device_info_get_size
|
||||
#undef mlx_device_info_get_keys
|
||||
#undef mlx_distributed_all_gather
|
||||
#undef mlx_distributed_all_max
|
||||
#undef mlx_distributed_all_min
|
||||
@@ -261,7 +274,6 @@
|
||||
#undef mlx_set_cache_limit
|
||||
#undef mlx_set_memory_limit
|
||||
#undef mlx_set_wired_limit
|
||||
#undef mlx_metal_device_info
|
||||
#undef mlx_metal_is_available
|
||||
#undef mlx_metal_start_capture
|
||||
#undef mlx_metal_stop_capture
|
||||
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_double_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
|
||||
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
|
||||
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
|
||||
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
|
||||
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
|
||||
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
|
||||
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
|
||||
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
|
||||
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
|
||||
#endif
|
||||
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
|
||||
extern int (*mlx_disable_compile_ptr)(void);
|
||||
extern int (*mlx_enable_compile_ptr)(void);
|
||||
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
|
||||
extern int (*mlx_cuda_is_available_ptr)(bool* res);
|
||||
extern mlx_device (*mlx_device_new_ptr)(void);
|
||||
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
|
||||
extern int (*mlx_device_free_ptr)(mlx_device dev);
|
||||
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
|
||||
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
|
||||
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
|
||||
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
|
||||
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
|
||||
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
|
||||
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
|
||||
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
|
||||
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
|
||||
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
|
||||
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
|
||||
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
|
||||
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
|
||||
extern int (*mlx_metal_is_available_ptr)(bool* res);
|
||||
extern int (*mlx_metal_start_capture_ptr)(const char* path);
|
||||
extern int (*mlx_metal_stop_capture_ptr)(void);
|
||||
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
|
||||
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
|
||||
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
|
||||
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
|
||||
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
|
||||
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void);
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
|
||||
@@ -3,5 +3,8 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
|
||||
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
|
||||
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
|
||||
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
|
||||
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
|
||||
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
|
||||
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
|
||||
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
|
||||
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
|
||||
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
|
||||
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
|
||||
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
|
||||
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
|
||||
const mlx_vector_array input) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
|
||||
const mlx_vector_array input_2) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
|
||||
size_t input_2_num) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_)(void) = NULL;
|
||||
int (*mlx_enable_compile_)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_)(mlx_device dev) = NULL;
|
||||
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
|
||||
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_)(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_)(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_)(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_)(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_rope_dynamic_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_scaled_dot_product_attention_)(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_load_reader_)(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_linalg_cholesky_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
|
||||
int (*mlx_metal_is_available_)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_)(void) = NULL;
|
||||
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
int axis,
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
|
||||
const mlx_array values,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array w,
|
||||
const mlx_array w_scales /* may be null */,
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array w,
|
||||
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_axis_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_max_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_segmented_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
|
||||
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
|
||||
const char * (*mlx_string_data_)(mlx_string str) = NULL;
|
||||
int (*mlx_string_free_)(mlx_string str) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
|
||||
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
|
||||
int (*mlx_custom_function_)(
|
||||
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
|
||||
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
||||
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
|
||||
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_new_double);
|
||||
CHECK_LOAD(handle, mlx_array_new_complex);
|
||||
CHECK_LOAD(handle, mlx_array_new_data);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
|
||||
CHECK_LOAD(handle, mlx_array_set);
|
||||
CHECK_LOAD(handle, mlx_array_set_bool);
|
||||
CHECK_LOAD(handle, mlx_array_set_int);
|
||||
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_disable_compile);
|
||||
CHECK_LOAD(handle, mlx_enable_compile);
|
||||
CHECK_LOAD(handle, mlx_set_compile_mode);
|
||||
CHECK_LOAD(handle, mlx_cuda_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_new);
|
||||
CHECK_LOAD(handle, mlx_device_new_type);
|
||||
CHECK_LOAD(handle, mlx_device_free);
|
||||
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_device_get_type);
|
||||
CHECK_LOAD(handle, mlx_get_default_device);
|
||||
CHECK_LOAD(handle, mlx_set_default_device);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_device_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_count);
|
||||
CHECK_LOAD(handle, mlx_device_info_new);
|
||||
CHECK_LOAD(handle, mlx_device_info_get);
|
||||
CHECK_LOAD(handle, mlx_device_info_free);
|
||||
CHECK_LOAD(handle, mlx_device_info_has_key);
|
||||
CHECK_LOAD(handle, mlx_device_info_is_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_size);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_keys);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_gather);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_max);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_min);
|
||||
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_distributed_recv_like);
|
||||
CHECK_LOAD(handle, mlx_distributed_send);
|
||||
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_set_error_handler);
|
||||
CHECK_LOAD(handle, _mlx_error);
|
||||
CHECK_LOAD(handle, mlx_export_function);
|
||||
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
|
||||
CHECK_LOAD(handle, mlx_fast_rms_norm);
|
||||
CHECK_LOAD(handle, mlx_fast_rope);
|
||||
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
|
||||
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
|
||||
CHECK_LOAD(handle, mlx_fft_fft);
|
||||
CHECK_LOAD(handle, mlx_fft_fft2);
|
||||
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fft_rfft);
|
||||
CHECK_LOAD(handle, mlx_fft_rfft2);
|
||||
CHECK_LOAD(handle, mlx_fft_rfftn);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
|
||||
CHECK_LOAD(handle, mlx_linalg_cross);
|
||||
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_set_cache_limit);
|
||||
CHECK_LOAD(handle, mlx_set_memory_limit);
|
||||
CHECK_LOAD(handle, mlx_set_wired_limit);
|
||||
CHECK_LOAD(handle, mlx_metal_device_info);
|
||||
CHECK_LOAD(handle, mlx_metal_is_available);
|
||||
CHECK_LOAD(handle, mlx_metal_start_capture);
|
||||
CHECK_LOAD(handle, mlx_metal_stop_capture);
|
||||
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_full);
|
||||
CHECK_LOAD(handle, mlx_full_like);
|
||||
CHECK_LOAD(handle, mlx_gather);
|
||||
CHECK_LOAD(handle, mlx_gather_single);
|
||||
CHECK_LOAD(handle, mlx_gather_mm);
|
||||
CHECK_LOAD(handle, mlx_gather_qmm);
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_prod_axis);
|
||||
CHECK_LOAD(handle, mlx_prod);
|
||||
CHECK_LOAD(handle, mlx_put_along_axis);
|
||||
CHECK_LOAD(handle, mlx_qqmm);
|
||||
CHECK_LOAD(handle, mlx_quantize);
|
||||
CHECK_LOAD(handle, mlx_quantized_matmul);
|
||||
CHECK_LOAD(handle, mlx_radians);
|
||||
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_round);
|
||||
CHECK_LOAD(handle, mlx_rsqrt);
|
||||
CHECK_LOAD(handle, mlx_scatter);
|
||||
CHECK_LOAD(handle, mlx_scatter_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_axis);
|
||||
CHECK_LOAD(handle, mlx_scatter_max);
|
||||
CHECK_LOAD(handle, mlx_scatter_max_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_min);
|
||||
CHECK_LOAD(handle, mlx_scatter_min_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod_single);
|
||||
CHECK_LOAD(handle, mlx_segmented_mm);
|
||||
CHECK_LOAD(handle, mlx_sigmoid);
|
||||
CHECK_LOAD(handle, mlx_sign);
|
||||
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_string_set);
|
||||
CHECK_LOAD(handle, mlx_string_data);
|
||||
CHECK_LOAD(handle, mlx_string_free);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_async_eval);
|
||||
CHECK_LOAD(handle, mlx_checkpoint);
|
||||
CHECK_LOAD(handle, mlx_custom_function);
|
||||
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_jvp);
|
||||
CHECK_LOAD(handle, mlx_value_and_grad);
|
||||
CHECK_LOAD(handle, mlx_vjp);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_vector_array_new);
|
||||
CHECK_LOAD(handle, mlx_vector_array_set);
|
||||
CHECK_LOAD(handle, mlx_vector_array_free);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,10 @@
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
|
||||
92
x/mlxrunner/model/linear.go
Normal file
92
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||
type LinearFactory struct {
|
||||
tensors map[string]*mlx.Array
|
||||
defaultGroupSize int
|
||||
defaultBits int
|
||||
defaultMode string
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||
func NewLinearFactory(
|
||||
tensors map[string]*mlx.Array,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) LinearFactory {
|
||||
return LinearFactory{
|
||||
tensors: tensors,
|
||||
defaultGroupSize: defaultGroupSize,
|
||||
defaultBits: defaultBits,
|
||||
defaultMode: defaultMode,
|
||||
tensorQuant: tensorQuant,
|
||||
}
|
||||
}
|
||||
|
||||
// Make constructs a linear layer at path.
|
||||
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||
return MakeLinearLayer(
|
||||
f.tensors,
|
||||
path,
|
||||
f.defaultGroupSize,
|
||||
f.defaultBits,
|
||||
f.defaultMode,
|
||||
f.tensorQuant,
|
||||
)
|
||||
}
|
||||
|
||||
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||
func MakeLinearLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
130
x/mlxrunner/model/quant.go
Normal file
130
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,130 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||
// when available, otherwise falling back to the provided model defaults.
|
||||
func TensorQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||
if tensorQuant != nil {
|
||||
if tq := tensorQuant[tensorName]; tq != nil {
|
||||
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||
if tq.GroupSize > 0 {
|
||||
groupSize = tq.GroupSize
|
||||
}
|
||||
return groupSize, bits, mode, true
|
||||
}
|
||||
}
|
||||
return defaultGroupSize, defaultBits, defaultMode, false
|
||||
}
|
||||
|
||||
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||
// inference for affine packed tensors.
|
||||
func ResolveLinearQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
weight, scales *mlx.Array,
|
||||
) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
tensorName,
|
||||
)
|
||||
|
||||
if mode == "affine" {
|
||||
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||
groupSize = inferredGroupSize
|
||||
bits = inferredBits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||
// tensors from packed weight and scale shapes.
|
||||
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||
if weight == nil || scales == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scales.Dims()
|
||||
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightCols := weightShape[len(weightShape)-1]
|
||||
scalesCols := scaleShape[len(scaleShape)-1]
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return 32, 4, true
|
||||
case groupSize8 == 64:
|
||||
return 64, 8, true
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
if hintBits == 8 {
|
||||
return 32, 8, true
|
||||
}
|
||||
if hintBits == 4 {
|
||||
return 64, 4, true
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return groupSize4, 4, true
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return groupSize8, 8, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func isCommonGroupSize(v int) bool {
|
||||
switch v {
|
||||
case 16, 32, 64, 128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -8,42 +8,63 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
quantType string
|
||||
groupSize int
|
||||
// TensorQuantInfo describes per-tensor quantization metadata.
|
||||
type TensorQuantInfo struct {
|
||||
QuantType string
|
||||
GroupSize int
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and pre-scans the first
|
||||
// tensor blob for quantization metadata (quant_type, group_size).
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
|
||||
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||
quantType string
|
||||
groupSize int
|
||||
|
||||
// Per-tensor quantization metadata.
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and scans tensor blobs for
|
||||
// quantization metadata.
|
||||
func Open(modelName string) (*Root, error) {
|
||||
m, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := &Root{Manifest: m}
|
||||
root := &Root{
|
||||
Manifest: m,
|
||||
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||
}
|
||||
|
||||
// Pre-scan first tensor blob for quantization metadata
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
meta, err := readBlobMetadata(blobPath)
|
||||
if err != nil || meta == nil {
|
||||
|
||||
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if qt := meta["quant_type"]; qt != "" {
|
||||
root.quantType = strings.ToUpper(qt)
|
||||
|
||||
for name, info := range infos {
|
||||
root.tensorQuant[name] = info
|
||||
}
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
fmt.Sscanf(gs, "%d", &root.groupSize)
|
||||
|
||||
if root.quantType == "" && blobQuantType != "" {
|
||||
root.quantType = strings.ToUpper(blobQuantType)
|
||||
root.groupSize = blobGroupSize
|
||||
if root.groupSize == 0 {
|
||||
root.groupSize = defaultGroupSize(root.quantType)
|
||||
}
|
||||
}
|
||||
break // only check the first tensor blob
|
||||
}
|
||||
|
||||
return root, nil
|
||||
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
|
||||
// Close is a no-op for now (future: release resources).
|
||||
func (r *Root) Close() {}
|
||||
|
||||
// QuantType returns the quantization type detected from tensor metadata.
|
||||
// QuantType returns the quantization type detected from the first tensor blob metadata.
|
||||
func (r *Root) QuantType() string { return r.quantType }
|
||||
|
||||
// GroupSize returns the quantization group size detected from tensor metadata.
|
||||
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
|
||||
func (r *Root) GroupSize() int { return r.groupSize }
|
||||
|
||||
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
|
||||
func readBlobMetadata(path string) (map[string]string, error) {
|
||||
// TensorQuant returns per-tensor quantization metadata if available.
|
||||
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.tensorQuant[name]
|
||||
}
|
||||
|
||||
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
|
||||
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
|
||||
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
|
||||
for k, v := range r.tensorQuant {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
copy := *v
|
||||
out[k] = ©
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultGroupSize(quantType string) int {
|
||||
groupSize, _, _ := QuantizationParams(quantType)
|
||||
return groupSize
|
||||
}
|
||||
|
||||
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header too large: %d", headerSize)
|
||||
if headerSize > 100*1024*1024 {
|
||||
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, err
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
if _, ok := header[name+".scale"]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = inferredGroup
|
||||
}
|
||||
if quantType == "" {
|
||||
continue
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = defaultGroupSize(quantType)
|
||||
}
|
||||
|
||||
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
return infos, globalQuantType, globalGroupSize, nil
|
||||
}
|
||||
|
||||
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return nil, err
|
||||
return "", 0
|
||||
}
|
||||
return meta, nil
|
||||
|
||||
quantType = meta["quant_type"]
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
groupSize, _ = strconv.Atoi(gs)
|
||||
}
|
||||
return quantType, groupSize
|
||||
}
|
||||
|
||||
func mainTensorNames(header map[string]json.RawMessage) []string {
|
||||
names := make([]string, 0, len(header))
|
||||
for name := range header {
|
||||
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
|
||||
type tensorShape struct {
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
mainRaw, ok := header[tensorName]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
scaleRaw, ok := header[tensorName+".scale"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var mainInfo tensorShape
|
||||
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var scaleInfo tensorShape
|
||||
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
|
||||
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return "INT4", 32
|
||||
case groupSize8 == 64:
|
||||
return "INT8", 64
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
h := strings.ToUpper(hintQuantType)
|
||||
if strings.Contains(h, "8") {
|
||||
return "INT8", 32
|
||||
}
|
||||
if strings.Contains(h, "4") {
|
||||
return "INT4", 64
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return "INT4", groupSize4
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return "INT8", groupSize8
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
|
||||
@@ -18,15 +18,27 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
mlx.EnableCompile()
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
if len(caches) == 0 {
|
||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
} else {
|
||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
521
x/models/gemma3/gemma3.go
Normal file
521
x/models/gemma3/gemma3.go
Normal file
@@ -0,0 +1,521 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Gemma3ForCausalLM", newModel)
|
||||
base.Register("Gemma3ForConditionalGeneration", newModel)
|
||||
}
|
||||
|
||||
// TextConfig holds configuration for the Gemma 3 text model.
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization).
|
||||
QuantGroupSize int `json:"-"`
|
||||
QuantBits int `json:"-"`
|
||||
QuantMode string `json:"-"`
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields.
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization.
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer
|
||||
KProj nn.LinearLayer
|
||||
VProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
|
||||
QNorm *nn.RMSNorm
|
||||
KNorm *nn.RMSNorm
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||
QNormScaled *mlx.Array
|
||||
KNormScaled *mlx.Array
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with GELU activation.
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block.
|
||||
type DecoderLayer struct {
|
||||
InputNorm *nn.RMSNorm
|
||||
Attention *Attention
|
||||
PostAttnNorm *nn.RMSNorm
|
||||
PreFFNorm *nn.RMSNorm
|
||||
MLP *MLP
|
||||
PostFFNorm *nn.RMSNorm
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||
InputNormScaled *mlx.Array
|
||||
PostAttnNormScaled *mlx.Array
|
||||
PreFFNormScaled *mlx.Array
|
||||
PostFFNormScaled *mlx.Array
|
||||
|
||||
// Layer metadata.
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Model is the Gemma 3 text-only model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
Layers []*DecoderLayer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||
NormScaled *mlx.Array
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*TextConfig
|
||||
|
||||
weightPrefix string
|
||||
}
|
||||
|
||||
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
|
||||
switch numLayers {
|
||||
case 34:
|
||||
return 8, 4
|
||||
case 48:
|
||||
return 16, 8
|
||||
case 62:
|
||||
return 32, 16
|
||||
default:
|
||||
return 8, 4
|
||||
}
|
||||
}
|
||||
|
||||
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
var wrapped struct {
|
||||
TextConfig *TextConfig `json:"text_config"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &wrapped); err != nil {
|
||||
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
|
||||
}
|
||||
|
||||
fromConditional := wrapped.TextConfig != nil
|
||||
if fromConditional {
|
||||
cfg = *wrapped.TextConfig
|
||||
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = 256
|
||||
}
|
||||
if cfg.NumAttentionHeads == 0 {
|
||||
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.NumKeyValueHeads == 0 {
|
||||
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208
|
||||
}
|
||||
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
|
||||
cfg.SlidingWindowPattern = 6
|
||||
}
|
||||
if cfg.MaxPositionEmbeddings == 0 {
|
||||
cfg.MaxPositionEmbeddings = 131072
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = 256
|
||||
}
|
||||
if cfg.NumAttentionHeads == 0 {
|
||||
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.NumKeyValueHeads == 0 {
|
||||
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208
|
||||
}
|
||||
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
return cfg, fromConditional, nil
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
|
||||
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
|
||||
return cfg.LayerTypes[layerIdx] == "sliding_attention"
|
||||
}
|
||||
if cfg.SlidingWindowPattern <= 0 {
|
||||
return false
|
||||
}
|
||||
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
|
||||
}
|
||||
|
||||
func precomputeGemmaScaledWeights(m *Model) {
|
||||
if m.Norm != nil {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
}
|
||||
|
||||
var scaled []*mlx.Array
|
||||
if m.NormScaled != nil {
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
if layer == nil || layer.Attention == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if layer.InputNorm != nil {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.InputNormScaled)
|
||||
}
|
||||
if layer.PostAttnNorm != nil {
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PostAttnNormScaled)
|
||||
}
|
||||
if layer.PreFFNorm != nil {
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PreFFNormScaled)
|
||||
}
|
||||
if layer.PostFFNorm != nil {
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.PostFFNormScaled)
|
||||
}
|
||||
|
||||
if layer.Attention.QNorm != nil {
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.Attention.QNormScaled)
|
||||
}
|
||||
if layer.Attention.KNorm != nil {
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
scaled = append(scaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
}
|
||||
|
||||
if len(scaled) > 0 {
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
}
|
||||
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
cfg, _, err := parseTextConfig(configData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
TextConfig: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), m.TextConfig),
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||
// to model fields.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||
prefix := m.weightPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||
if embedWeight == nil {
|
||||
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||
}
|
||||
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||
}
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||
|
||||
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||
m.LMHead = lmHead
|
||||
} else {
|
||||
// Gemma usually ties output projection to embeddings.
|
||||
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||
}
|
||||
|
||||
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||
|
||||
layer := &DecoderLayer{
|
||||
LayerIdx: i,
|
||||
IsSliding: isLayerSliding(i, m.TextConfig),
|
||||
Attention: &Attention{},
|
||||
MLP: &MLP{},
|
||||
}
|
||||
|
||||
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||
|
||||
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
|
||||
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||
}
|
||||
|
||||
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||
|
||||
if layer.InputNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||
}
|
||||
if layer.PostAttnNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||
}
|
||||
if layer.PreFFNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
|
||||
}
|
||||
if layer.PostFFNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
|
||||
}
|
||||
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||
}
|
||||
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
|
||||
return fmt.Errorf("layer %d: missing attention q/k norms", i)
|
||||
}
|
||||
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
precomputeGemmaScaledWeights(m)
|
||||
if m.NormScaled == nil {
|
||||
return fmt.Errorf("missing precomputed final norm weight")
|
||||
}
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
|
||||
}
|
||||
|
||||
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
// NewCaches creates cache objects for all layers.
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if m.SlidingWindow > 0 && layer.IsSliding {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
@@ -64,9 +63,10 @@ type Config struct {
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
|
||||
return mode == "affine" && (bits == 4 || bits == 8)
|
||||
}
|
||||
|
||||
// quantizationParams returns groupSize, bits, mode for a quantization type string.
|
||||
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||
type ExpertWeight struct {
|
||||
Weight *mlx.Array
|
||||
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
|
||||
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
|
||||
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||
cfg.QuantGroupSize,
|
||||
cfg.QuantBits,
|
||||
cfg.QuantMode,
|
||||
cfg.TensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
||||
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
||||
// Check if quantized and dequantize
|
||||
if scales := tensors[path+".weight_scale"]; scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
|
||||
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||
cfg.QuantGroupSize,
|
||||
cfg.QuantBits,
|
||||
cfg.QuantMode,
|
||||
cfg.TensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
||||
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
||||
return embedQ, unembedOut
|
||||
}
|
||||
|
||||
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
|
||||
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: cfg.QuantGroupSize,
|
||||
Bits: cfg.QuantBits,
|
||||
Mode: cfg.QuantMode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
|
||||
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
||||
// no weights loaded yet). Called by the registry via base.New().
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
|
||||
|
||||
// Set up quantization parameters from pre-scanned metadata
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
} else {
|
||||
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
// Load tokenizer
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
|
||||
// layer creation.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
cfg := m.Config
|
||||
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
||||
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
if !useQuantized && cfg.TensorQuant != nil {
|
||||
for _, tq := range cfg.TensorQuant {
|
||||
if tq == nil {
|
||||
continue
|
||||
}
|
||||
_, bits, mode := model.QuantizationParams(tq.QuantType)
|
||||
if supportsGatherQMM(mode, bits) {
|
||||
useQuantized = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load embedding
|
||||
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
||||
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
// Load LM head
|
||||
m.LMHead = makeLinear(tensors, "lm_head", cfg)
|
||||
m.LMHead = linears.Make("lm_head")
|
||||
|
||||
// Load layers
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
|
||||
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
|
||||
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
||||
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
|
||||
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
|
||||
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
|
||||
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
|
||||
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
||||
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
|
||||
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
|
||||
|
||||
// Sanitize MLA weights for absorbed attention
|
||||
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
||||
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
block.MLP = &DenseMLP{
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
|
||||
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
|
||||
UpProj: linears.Make(prefix + ".mlp.up_proj"),
|
||||
DownProj: linears.Make(prefix + ".mlp.down_proj"),
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
moeGate := &MoEGate{}
|
||||
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
|
||||
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
|
||||
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
||||
moeGate.EScoreCorrectionBias = bias
|
||||
}
|
||||
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
|
||||
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
|
||||
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
|
||||
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
323
x/models/llama/llama.go
Normal file
323
x/models/llama/llama.go
Normal file
@@ -0,0 +1,323 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package llama provides a Llama-style decoder-only transformer for MLX.
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("LlamaForCausalLM", newModel)
|
||||
}
|
||||
|
||||
// Config holds Llama model configuration.
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization).
|
||||
QuantGroupSize int `json:"-"`
|
||||
QuantBits int `json:"-"`
|
||||
QuantMode string `json:"-"`
|
||||
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||
|
||||
// Computed fields.
|
||||
HeadDim int32 `json:"-"`
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// Model is a Llama text model.
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding
|
||||
Layers []*Layer
|
||||
Norm *nn.RMSNorm
|
||||
LMHead nn.LinearLayer
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
|
||||
weightPrefix string
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm
|
||||
MLPNorm *nn.RMSNorm
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer
|
||||
KProj nn.LinearLayer
|
||||
VProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize <= 0 {
|
||||
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumAttentionHeads <= 0 {
|
||||
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
||||
}
|
||||
if cfg.NumKeyValueHeads <= 0 {
|
||||
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
||||
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
||||
}
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
if cfg.HeadDim <= 0 {
|
||||
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||
}
|
||||
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
|
||||
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-5
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
}
|
||||
} else {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||
}
|
||||
cfg.TensorQuant = root.AllTensorQuant()
|
||||
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData,
|
||||
}
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. It returns an error if any required tensor is absent.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	// Checkpoints may prefix tensor names (e.g. "language_model."); detect once.
	m.weightPrefix = resolveWeightPrefix(tensors)
	prefix := m.weightPrefix
	linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)

	embedWeight := tensors[prefix+"model.embed_tokens.weight"]
	if embedWeight == nil {
		return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
	}
	m.EmbedTokens = nn.NewEmbedding(embedWeight)

	normWeight := tensors[prefix+"model.norm.weight"]
	if normWeight == nil {
		return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
	}
	m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)

	// Resolve the output projection: explicit tying, then a prefixed or
	// unprefixed lm_head tensor, and finally fall back to tied embeddings.
	if m.TieWordEmbeddings {
		m.LMHead = nn.NewLinear(embedWeight, nil)
	} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else if lmHead := linears.Make("lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else {
		// Fallback used by many Llama checkpoints where output is tied.
		m.LMHead = nn.NewLinear(embedWeight, nil)
	}

	for i := int32(0); i < m.NumHiddenLayers; i++ {
		layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)

		layer := &Layer{
			Attention: &Attention{},
			MLP:       &MLP{},
		}

		if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
			layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}
		if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
			layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}

		layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
		layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
		layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
		layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")

		layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
		layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
		layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")

		// Validate the layer is complete before installing it.
		if layer.AttentionNorm == nil {
			return fmt.Errorf("layer %d: missing input_layernorm", i)
		}
		if layer.MLPNorm == nil {
			return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
		}
		if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
			return fmt.Errorf("layer %d: missing attention projections", i)
		}
		if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
			return fmt.Errorf("layer %d: missing mlp projections", i)
		}

		m.Layers[i] = layer
	}

	// Force evaluation of every collected array so load errors surface now
	// rather than lazily during the first forward pass.
	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
}
|
||||
|
||||
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
338
x/models/qwen3/qwen3.go
Normal file
338
x/models/qwen3/qwen3.go
Normal file
@@ -0,0 +1,338 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3 provides the Qwen3 text model implementation for MLX.
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// init registers the Qwen3 architecture with the model loader so that
// checkpoints declaring "Qwen3ForCausalLM" resolve to this implementation.
func init() {
	base.Register("Qwen3ForCausalLM", newModel)
}
|
||||
|
||||
// Config holds Qwen3 model configuration, unmarshaled from the checkpoint's
// config.json. Zero-valued fields are filled with defaults in newModel.
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"` // defaults to NumAttentionHeads when absent
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`   // defaults to 1e-6
	RopeTheta             float32 `json:"rope_theta"`     // defaults to 1000000
	HeadDim               int32   `json:"head_dim"`       // defaults to HiddenSize/NumAttentionHeads
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	TieWordEmbeddings     bool    `json:"tie_word_embeddings"`

	// Quantization parameters (set during load based on model quantization).
	QuantGroupSize int                               `json:"-"`
	QuantBits      int                               `json:"-"`
	QuantMode      string                            `json:"-"`
	TensorQuant    map[string]*model.TensorQuantInfo `json:"-"`

	// Computed fields.
	Scale     float32 `json:"-"` // attention scale, 1/sqrt(HeadDim)
	QKNormEps float32 `json:"-"` // epsilon for the per-head Q/K RMS norms (1e-6)
}
|
||||
|
||||
// Model is the Qwen3 text-only model.
type Model struct {
	EmbedTokens *nn.Embedding  // token embedding table
	Layers      []*Layer       // decoder blocks, one per hidden layer
	Norm        *nn.RMSNorm    // final RMS norm applied before unembedding
	LMHead      nn.LinearLayer // output projection; may share weights with EmbedTokens

	tok *tokenizer.Tokenizer // tokenizer loaded from the manifest
	*Config

	// weightPrefix is the detected tensor-name prefix ("" or
	// "language_model."); see resolveWeightPrefix.
	weightPrefix string
}
|
||||
|
||||
// Layer is a single Qwen3 decoder block (pre-norm attention + pre-norm MLP).
type Layer struct {
	Attention     *Attention
	MLP           *MLP
	AttentionNorm *nn.RMSNorm // input_layernorm, applied before attention
	MLPNorm       *nn.RMSNorm // post_attention_layernorm, applied before the MLP
}
|
||||
|
||||
// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
	QProj nn.LinearLayer
	KProj nn.LinearLayer
	VProj nn.LinearLayer
	OProj nn.LinearLayer
	QNorm *nn.RMSNorm // per-head RMS norm on queries (Qwen3-specific)
	KNorm *nn.RMSNorm // per-head RMS norm on keys (Qwen3-specific)
}
|
||||
|
||||
// MLP is the feed-forward network with SwiGLU activation:
// down_proj(silu(gate_proj(x)) * up_proj(x)).
type MLP struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model."} {
|
||||
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// newModel constructs a Qwen3 model from the manifest's config.json and
// tokenizer files. Weights are not loaded here; see LoadWeights.
func newModel(root *model.Root) (base.Model, error) {
	configData, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Validate required dimensions and fill in conventional Qwen3 defaults.
	if cfg.HiddenSize <= 0 {
		return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
	}
	if cfg.NumAttentionHeads <= 0 {
		return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
	}
	if cfg.NumKeyValueHeads <= 0 {
		// An absent num_key_value_heads means full multi-head attention.
		cfg.NumKeyValueHeads = cfg.NumAttentionHeads
	}
	if cfg.HeadDim == 0 {
		if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
			return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
		}
		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	}
	if cfg.HeadDim <= 0 {
		return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
	}
	if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
		return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	cfg.QKNormEps = 1e-6

	// Derive quantization parameters from the manifest, allowing an explicit
	// group-size override when present.
	if qt := root.QuantType(); qt != "" {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
		if gs := root.GroupSize(); gs > 0 {
			cfg.QuantGroupSize = gs
		}
	} else {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
	}
	cfg.TensorQuant = root.AllTensorQuant()

	tokData, err := root.Manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	// generation_config.json and tokenizer_config.json are optional; missing
	// files are silently skipped.
	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData,
	}
	if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}
	if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]*Layer, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	return m, nil
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. It returns an error if any required tensor is absent.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	// Checkpoints may prefix tensor names (e.g. "language_model."); detect once.
	m.weightPrefix = resolveWeightPrefix(tensors)
	prefix := m.weightPrefix
	linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)

	embedWeight := tensors[prefix+"model.embed_tokens.weight"]
	if embedWeight == nil {
		return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
	}
	m.EmbedTokens = nn.NewEmbedding(embedWeight)

	normWeight := tensors[prefix+"model.norm.weight"]
	if normWeight == nil {
		return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
	}
	m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)

	// Resolve the output projection: explicit tying, then a prefixed or
	// unprefixed lm_head tensor, and finally fall back to tied embeddings.
	if m.TieWordEmbeddings {
		m.LMHead = nn.NewLinear(embedWeight, nil)
	} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else if lmHead := linears.Make("lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else {
		// Qwen3 checkpoints commonly tie output projection to embeddings.
		m.LMHead = nn.NewLinear(embedWeight, nil)
	}

	for i := int32(0); i < m.NumHiddenLayers; i++ {
		layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)

		layer := &Layer{
			Attention: &Attention{},
			MLP:       &MLP{},
		}

		if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
			layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}
		if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
			layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}

		layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
		layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
		layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
		layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")

		// Qwen3's per-head Q/K norms use QKNormEps rather than RMSNormEps.
		if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
			layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
		}
		if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
			layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
		}

		layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
		layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
		layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")

		// Validate the layer is complete before installing it.
		if layer.AttentionNorm == nil {
			return fmt.Errorf("layer %d: missing input_layernorm", i)
		}
		if layer.MLPNorm == nil {
			return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
		}
		if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
			return fmt.Errorf("layer %d: missing attention projections", i)
		}
		if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
			return fmt.Errorf("layer %d: missing attention q/k norms", i)
		}
		if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
			return fmt.Errorf("layer %d: missing mlp projections", i)
		}

		m.Layers[i] = layer
	}

	// Force evaluation of every collected array so load errors surface now
	// rather than lazily during the first forward pass.
	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
}
|
||||
|
||||
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
return m.LMHead.Forward(x)
|
||||
}
|
||||
|
||||
func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
// Forward computes Qwen3 self-attention: grouped-query attention with
// per-head Q/K RMS norms and RoPE, plus an optional KV cache.
// x is [B, L, hidden]; the result has the same shape.
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Split heads: [B, L, H*D] -> [B, L, H, D].
	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
	k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
	v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)

	// Qwen3 applies RMS norm over each head's last dimension before the
	// transpose to [B, H, L, D].
	q = a.QNorm.Forward(q, cfg.QKNormEps)
	k = a.KNorm.Forward(k, cfg.QKNormEps)

	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Rotate positions starting at the cache offset so cached tokens keep
	// their original absolute positions during incremental decoding.
	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
	k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)

	if c != nil {
		k, v = c.Update(k, v)
	}

	// MLX SDPA supports grouped-query attention directly (Q heads can be a
	// multiple of K/V heads), so avoid materializing repeated K/V tensors.
	out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return buildModelInfo(config, totalBytes, tensorCount), nil
|
||||
info := buildModelInfo(config, totalBytes, tensorCount)
|
||||
|
||||
// For quantized models, byte-based estimation can significantly undercount
|
||||
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
|
||||
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
|
||||
info["general.parameter_count"] = paramCount
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// buildModelInfo constructs the model info map from config and tensor stats.
|
||||
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
||||
return info
|
||||
}
|
||||
|
||||
// getParameterCountFromManifest counts model parameters from tensor shapes.
|
||||
// This accounts for quantized tensors by using unpacked shapes from
|
||||
// getTensorInfoFromManifest.
|
||||
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
|
||||
tensors, err := getTensorInfoFromManifest(mf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, tensor := range tensors {
|
||||
if len(tensor.Shape) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
elements := int64(1)
|
||||
for _, dim := range tensor.Shape {
|
||||
if dim == 0 {
|
||||
elements = 0
|
||||
break
|
||||
}
|
||||
|
||||
if dim > uint64(math.MaxInt64) {
|
||||
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
|
||||
}
|
||||
|
||||
d := int64(dim)
|
||||
if elements > math.MaxInt64/d {
|
||||
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
|
||||
}
|
||||
elements *= d
|
||||
}
|
||||
|
||||
if elements == 0 {
|
||||
continue
|
||||
}
|
||||
if total > math.MaxInt64-elements {
|
||||
return 0, fmt.Errorf("total parameter count overflow")
|
||||
}
|
||||
total += elements
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||
|
||||
@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetParameterCountFromManifest verifies that parameter counting sums the
// unpacked shapes of both plain and int4-quantized tensors stored as minimal
// safetensors blobs.
func TestGetParameterCountFromManifest(t *testing.T) {
	// Create a temp directory for blobs and set OLLAMA_MODELS
	tempDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", tempDir)

	blobDir := filepath.Join(tempDir, "blobs")
	if err := os.MkdirAll(blobDir, 0o755); err != nil {
		t.Fatalf("failed to create blobs dir: %v", err)
	}

	// Unquantized tensor: [4,5] = 20 params
	header1 := map[string]any{
		"model.embed_tokens.weight": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{4, 5},
			"data_offsets": []int64{0, 40},
		},
	}
	header1JSON, _ := json.Marshal(header1)
	var buf1 bytes.Buffer
	// Safetensors layout: 8-byte little-endian header length, then the JSON
	// header. Write errors on an in-memory buffer cannot occur.
	binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
	buf1.Write(header1JSON)

	digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
	blobPath1, err := manifest.BlobsPath(digest1)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob1: %v", err)
	}

	// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
	header2 := map[string]any{
		"__metadata__": map[string]string{
			"quant_type": "int4",
			"group_size": "32",
		},
		"model.layers.0.mlp.up_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{10, 2},
			"data_offsets": []int64{0, 80},
		},
		"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{10, 1},
			"data_offsets": []int64{80, 100},
		},
		"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{10, 1},
			"data_offsets": []int64{100, 120},
		},
	}
	header2JSON, _ := json.Marshal(header2)
	var buf2 bytes.Buffer
	binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
	buf2.Write(header2JSON)

	digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
	blobPath2, err := manifest.BlobsPath(digest2)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob2: %v", err)
	}

	// Layer sizes include the (absent) tensor data beyond the header; only
	// the header is parsed by the counter.
	mf := &manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Layers: []manifest.Layer{
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest1,
				Size:      int64(buf1.Len() + 40),
				Name:      "model.embed_tokens.weight",
			},
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest2,
				Size:      int64(buf2.Len() + 120),
				Name:      "model.layers.0.mlp.up_proj.weight",
			},
		},
	}

	paramCount, err := getParameterCountFromManifest(mf)
	if err != nil {
		t.Fatalf("getParameterCountFromManifest() error = %v", err)
	}

	const want int64 = 180 // 20 + 160
	if paramCount != want {
		t.Errorf("parameter_count = %d, want %d", paramCount, want)
	}
}
|
||||
|
||||
// TestGetParameterCountFromManifest_MixedQuantizedPacked verifies parameter
// counting for a single packed blob holding tensors of different quantization
// widths (int4 and int8), inferred per-tensor from packed shapes and scales
// rather than from global metadata.
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
	// Create a temp directory for blobs and set OLLAMA_MODELS
	tempDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", tempDir)

	blobDir := filepath.Join(tempDir, "blobs")
	if err := os.MkdirAll(blobDir, 0o755); err != nil {
		t.Fatalf("failed to create blobs dir: %v", err)
	}

	// Packed mixed-precision blob (no global metadata):
	// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
	// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
	header := map[string]any{
		"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{5, 8},
			"data_offsets": []int64{0, 160},
		},
		"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 2},
			"data_offsets": []int64{160, 180},
		},
		"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 2},
			"data_offsets": []int64{180, 200},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{5, 16},
			"data_offsets": []int64{200, 520},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 1},
			"data_offsets": []int64{520, 530},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 1},
			"data_offsets": []int64{530, 540},
		},
	}
	headerJSON, _ := json.Marshal(header)
	var buf bytes.Buffer
	// Safetensors layout: 8-byte little-endian header length, then the JSON
	// header. Write errors on an in-memory buffer cannot occur.
	binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
	buf.Write(headerJSON)

	digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob: %v", err)
	}

	mf := &manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Layers: []manifest.Layer{
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest,
				Size:      int64(buf.Len() + 540),
				Name:      "model.layers.0.mlp.experts",
			},
		},
	}

	paramCount, err := getParameterCountFromManifest(mf)
	if err != nil {
		t.Fatalf("getParameterCountFromManifest() error = %v", err)
	}

	const want int64 = 640 // 320 + 320
	if paramCount != want {
		t.Errorf("parameter_count = %d, want %d", paramCount, want)
	}
}
|
||||
|
||||
func TestParseSafetensorsAllHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user