Compare commits

..

9 Commits

Author SHA1 Message Date
Patrick Devine
d5ac80125f colour changes + feed the linter 2026-02-06 12:26:47 -08:00
Patrick Devine
fb325cf88a remove logo 2026-02-06 11:15:11 -08:00
Patrick Devine
c54b99fa6a fix go.sum 2026-02-06 11:15:11 -08:00
Patrick Devine
248a183e0b feed the linter 2026-02-06 11:15:11 -08:00
Patrick Devine
cab70db823 model selector improvments 2026-02-06 11:15:11 -08:00
Patrick Devine
6e32d73afe fix unit test 2026-02-06 11:15:11 -08:00
Patrick Devine
ac1fefc52d remember last selection 2026-02-06 11:15:11 -08:00
Patrick Devine
1de2e8f1a6 gofumpt the linter 2026-02-06 11:15:11 -08:00
Patrick Devine
4ec13bc642 launch: new menu system for ollama launch 2026-02-06 11:15:11 -08:00
55 changed files with 2095 additions and 15902 deletions

13
cmd/background_unix.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "syscall"
// backgroundServerSysProcAttr returns the SysProcAttr used when launching
// the server as a background process on Unix-like systems.
// Setpgid places the child in its own process group so it is not killed
// when the parent process exits.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
	attr := new(syscall.SysProcAttr)
	attr.Setpgid = true
	return attr
}

12
cmd/background_windows.go Normal file
View File

@@ -0,0 +1,12 @@
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
CreationFlags: 0x08000000,
HideWindow: true,
}
}

View File

@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -37,6 +38,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -1804,6 +1806,190 @@ Environment Variables:
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
}
// ensureServerRunning checks whether the ollama server is reachable and, if
// not, re-executes this binary with the "serve" subcommand as a detached
// background process, then waits for it to answer a heartbeat.
//
// Unlike a bare poll loop, the wait honors ctx cancellation and gives up
// after a bounded startup window, so a server that fails to come up cannot
// hang the caller forever.
func ensureServerRunning(ctx context.Context) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}
	// Fast path: server is already running.
	if err := client.Heartbeat(ctx); err == nil {
		return nil
	}
	// Server not running, start it in the background.
	exe, err := os.Executable()
	if err != nil {
		return fmt.Errorf("could not find executable: %w", err)
	}
	serverCmd := exec.CommandContext(ctx, exe, "serve")
	serverCmd.Env = os.Environ()
	// Detach per-OS so the server outlives this process (see background_*.go).
	serverCmd.SysProcAttr = backgroundServerSysProcAttr()
	if err := serverCmd.Start(); err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}
	// Poll until the server answers, the context is cancelled, or we time out.
	const startupTimeout = 30 * time.Second
	deadline := time.NewTimer(startupTimeout)
	defer deadline.Stop()
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline.C:
			return fmt.Errorf("server did not become ready within %s", startupTimeout)
		case <-ticker.C:
			if err := client.Heartbeat(ctx); err == nil {
				return nil // server has started
			}
		}
	}
}
// runInteractiveTUI runs the main interactive TUI menu loop.
//
// It first makes sure the local server is up, then repeatedly shows the main
// menu (tui.Run) and dispatches on the user's selection: run the last model,
// pick a different model, launch an integration, or reconfigure an
// integration. The loop exits when the user quits or tui.Run fails.
func runInteractiveTUI(cmd *cobra.Command) {
	// Ensure the server is running before showing the TUI
	if err := ensureServerRunning(cmd.Context()); err != nil {
		fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
		return
	}
	// errSelectionCancelled is returned when user cancels model selection
	errSelectionCancelled := errors.New("cancelled")
	// Selector adapters for tui: convert config.ModelItem slices into
	// tui.SelectItem and map tui.ErrCancelled onto errSelectionCancelled so
	// callers can tell "user backed out" apart from real errors.
	singleSelector := func(title string, items []config.ModelItem) (string, error) {
		tuiItems := make([]tui.SelectItem, len(items))
		for i, item := range items {
			tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
		}
		result, err := tui.SelectSingle(title, tuiItems)
		if errors.Is(err, tui.ErrCancelled) {
			return "", errSelectionCancelled
		}
		return result, err
	}
	multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
		tuiItems := make([]tui.SelectItem, len(items))
		for i, item := range items {
			tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
		}
		result, err := tui.SelectMultiple(title, tuiItems, preChecked)
		if errors.Is(err, tui.ErrCancelled) {
			return nil, errSelectionCancelled
		}
		return result, err
	}
	for {
		result, err := tui.Run()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			return
		}
		// runModel remembers the choice, loads the model, then drops into
		// the interactive generate session. Errors are reported to stderr
		// and control returns to the main menu.
		runModel := func(modelName string) {
			// Best-effort persistence; a failed save is not fatal.
			_ = config.SetLastModel(modelName)
			opts := runOptions{
				Model:       modelName,
				WordWrap:    os.Getenv("TERM") == "xterm-256color",
				Options:     map[string]any{},
				ShowConnect: true,
			}
			if err := loadOrUnloadModel(cmd, &opts); err != nil {
				fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
				return
			}
			if err := generateInteractive(cmd, opts); err != nil {
				fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
			}
		}
		// launchIntegration configures (if needed) and launches an
		// integration. It returns false only when the user cancelled model
		// selection, signalling the caller to return to the main menu.
		launchIntegration := func(name string) bool {
			// If not configured or model no longer exists, prompt for model selection
			configuredModel := config.IntegrationModel(name)
			if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
				err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
				if errors.Is(err, errSelectionCancelled) {
					return false // Return to main menu
				}
				if err != nil {
					fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
					return true
				}
			}
			if err := config.LaunchIntegration(name); err != nil {
				fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
			}
			return true
		}
		switch result.Selection {
		case tui.SelectionNone:
			// User quit
			return
		case tui.SelectionRunModel:
			_ = config.SetLastSelection("run")
			// Run last model directly if configured and still exists
			if modelName := config.LastModel(); modelName != "" && config.ModelExists(cmd.Context(), modelName) {
				runModel(modelName)
			} else {
				// No last model or model no longer exists, show picker
				modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
				if errors.Is(err, errSelectionCancelled) {
					continue // Return to main menu
				}
				if err != nil {
					fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
					continue
				}
				runModel(modelName)
			}
		case tui.SelectionChangeRunModel:
			_ = config.SetLastSelection("run")
			// Use model from modal if selected, otherwise show picker
			modelName := result.Model
			if modelName == "" {
				var err error
				modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
				if errors.Is(err, errSelectionCancelled) {
					continue // Return to main menu
				}
				if err != nil {
					fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
					continue
				}
			}
			runModel(modelName)
		case tui.SelectionIntegration:
			_ = config.SetLastSelection(result.Integration)
			if !launchIntegration(result.Integration) {
				continue // Return to main menu
			}
		case tui.SelectionChangeIntegration:
			_ = config.SetLastSelection(result.Integration)
			// Use model from modal if selected, otherwise show picker
			if result.Model != "" {
				// Model already selected from modal - save and launch
				if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
					fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
					continue
				}
				if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
					fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
				}
			} else {
				err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
				if errors.Is(err, errSelectionCancelled) {
					continue // Return to main menu
				}
				if err != nil {
					fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
					continue
				}
				if err := config.LaunchIntegration(result.Integration); err != nil {
					fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
				}
			}
		}
	}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
@@ -1826,11 +2012,13 @@ func NewCLI() *cobra.Command {
return
}
cmd.Print(cmd.UsageString())
runInteractiveTUI(cmd)
},
}
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -2044,7 +2232,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat),
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd

View File

@@ -3,12 +3,15 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
)
type integration struct {
@@ -17,7 +20,9 @@ type integration struct {
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"`
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
}
func configPath() (string, error) {
@@ -146,6 +151,74 @@ func saveIntegration(appName string, models []string) error {
return save(cfg)
}
// IntegrationModel returns the first (default) configured model for an
// integration, or the empty string when the integration has no saved
// configuration or it cannot be read.
func IntegrationModel(appName string) string {
	cfg, err := loadIntegration(appName)
	if err != nil || len(cfg.Models) == 0 {
		return ""
	}
	return cfg.Models[0]
}
// LastModel reports the most recently run model, or "" when none is
// recorded or the config file cannot be read.
func LastModel() string {
	if cfg, err := load(); err == nil {
		return cfg.LastModel
	}
	return ""
}
// SetLastModel records model as the most recently run model in the config
// file, returning any load/save error.
func SetLastModel(model string) error {
	cfg, err := load()
	if err != nil {
		return err
	}
	cfg.LastModel = model
	return save(cfg)
}
// LastSelection reports the last menu selection ("run" or an integration
// name), or "" when none is recorded or the config cannot be read.
func LastSelection() string {
	if cfg, err := load(); err == nil {
		return cfg.LastSelection
	}
	return ""
}
// SetLastSelection records the last menu selection ("run" or an integration
// name) in the config file, returning any load/save error.
func SetLastSelection(selection string) error {
	cfg, err := load()
	if err != nil {
		return err
	}
	cfg.LastSelection = selection
	return save(cfg)
}
// ModelExists reports whether a model with the given name — either an exact
// match or any tag of it (e.g. "name:latest") — is present on the Ollama
// server. It returns false for an empty name and on any client or listing
// error.
func ModelExists(ctx context.Context, name string) bool {
	if name == "" {
		return false
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return false
	}
	resp, err := client.List(ctx)
	if err != nil {
		return false
	}
	tagged := name + ":"
	for _, m := range resp.Models {
		if m.Name == name || strings.HasPrefix(m.Name, tagged) {
			return true
		}
	}
	return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -4,9 +4,12 @@ import (
"context"
"errors"
"fmt"
"io"
"maps"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
@@ -61,7 +64,7 @@ var integrations = map[string]Runner{
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
var recommendedModels = []ModelItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
@@ -74,6 +77,256 @@ var integrationAliases = map[string]bool{
"moltbot": true,
}
// integrationInstallURLs maps integration names to their install script URLs.
var integrationInstallURLs = map[string]string{
"claude": "https://claude.ai/install.sh",
"openclaw": "https://openclaw.ai/install.sh",
"droid": "https://app.factory.ai/cli",
"opencode": "https://opencode.ai/install",
}
// CanInstallIntegration reports whether an install script URL is known for
// the named integration.
func CanInstallIntegration(name string) bool {
	_, found := integrationInstallURLs[name]
	return found
}
// IsIntegrationInstalled reports whether the binary for an integration can
// be located on this machine. Integrations this function does not know
// about are assumed to be installed.
func IsIntegrationInstalled(name string) bool {
	onPath := func(bin string) bool {
		_, err := exec.LookPath(bin)
		return err == nil
	}
	switch name {
	case "claude":
		// Claude has its own path-discovery logic beyond a PATH lookup.
		c := &Claude{}
		_, err := c.findPath()
		return err == nil
	case "openclaw":
		// The binary has shipped under two different names.
		return onPath("openclaw") || onPath("clawdbot")
	case "codex", "droid", "opencode":
		return onPath(name)
	default:
		return true // Assume installed for unknown integrations
	}
}
// InstallIntegration downloads the install script for the named integration
// and runs it with bash, wiring the script to the current terminal so it can
// prompt the user.
//
// Returns an error when no install URL is known for name, the download
// fails or returns a non-200 status, the script cannot be written, or the
// script exits non-zero.
func InstallIntegration(name string) error {
	url, ok := integrationInstallURLs[name]
	if !ok {
		return fmt.Errorf("no install script available for %s", name)
	}
	// Download the install script with a bounded client so a stalled
	// connection cannot hang this call forever (http.Get has no timeout).
	httpClient := &http.Client{Timeout: 60 * time.Second}
	resp, err := httpClient.Get(url)
	if err != nil {
		return fmt.Errorf("failed to download install script: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download install script: HTTP %d", resp.StatusCode)
	}
	script, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read install script: %w", err)
	}
	// Create a temporary file for the script (0700 so only this user can run it).
	// NOTE(review): the path is fixed/predictable within TempDir; consider
	// os.CreateTemp to avoid clobbering by concurrent installs.
	tmpDir := os.TempDir()
	scriptPath := filepath.Join(tmpDir, fmt.Sprintf("install-%s.sh", name))
	if err := os.WriteFile(scriptPath, script, 0o700); err != nil {
		return fmt.Errorf("failed to write install script: %w", err)
	}
	defer os.Remove(scriptPath)
	// Execute the script with bash; it intentionally runs with the user's
	// privileges and full terminal access.
	cmd := exec.Command("bash", scriptPath)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("install script failed: %w", err)
	}
	return nil
}
// ModelItem represents a model that can be offered for selection, with an
// optional human-readable description (e.g. "recommended").
type ModelItem struct {
	Name        string
	Description string
}

// SingleSelector prompts the user to choose exactly one item from items and
// returns the chosen name.
type SingleSelector func(title string, items []ModelItem) (string, error)

// MultiSelector prompts the user to choose any number of items, with the
// names listed in preChecked initially selected, and returns the chosen
// names.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
// SelectModelWithSelector prompts the user to select a model using the
// provided selector.
//
// The offered list merges models installed on the server with
// recommendations (via buildModelList), ordered: last-run model first, then
// installed models, then recommendations, each group alphabetized. If the
// chosen model is not installed it is pulled (after confirmation); if it is
// a cloud model the user is walked through sign-in, polling until the
// account becomes visible or ctx is cancelled.
//
// Returns the selected model name, or an error (errCancelled when the user
// declines a download).
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return "", err
	}
	models, err := client.List(ctx)
	if err != nil {
		return "", err
	}
	var existing []modelInfo
	for _, m := range models.Models {
		// A non-empty RemoteModel marks a cloud-hosted model.
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}
	lastModel := LastModel()
	var preChecked []string
	if lastModel != "" {
		preChecked = []string{lastModel}
	}
	items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
	if len(items) == 0 {
		return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
	}
	// Sort with last model first, then existing models, then recommendations
	slices.SortStableFunc(items, func(a, b ModelItem) int {
		aIsLast := a.Name == lastModel
		bIsLast := b.Name == lastModel
		if aIsLast != bIsLast {
			if aIsLast {
				return -1
			}
			return 1
		}
		aExists := existingModels[a.Name]
		bExists := existingModels[b.Name]
		if aExists != bExists {
			if aExists {
				return -1
			}
			return 1
		}
		// Within a group, case-insensitive alphabetical order.
		return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
	})
	selected, err := selector("Select model to run:", items)
	if err != nil {
		return "", err
	}
	// If the selected model isn't installed, pull it first
	if !existingModels[selected] {
		msg := fmt.Sprintf("Download %s?", selected)
		if ok, err := confirmPrompt(msg); err != nil {
			return "", err
		} else if !ok {
			return "", errCancelled
		}
		fmt.Fprintf(os.Stderr, "\n")
		if err := pullModel(ctx, client, selected); err != nil {
			return "", fmt.Errorf("failed to pull %s: %w", selected, err)
		}
	}
	// If it's a cloud model, ensure user is signed in
	if cloudModels[selected] {
		user, err := client.Whoami(ctx)
		if err == nil && user != nil && user.Name != "" {
			return selected, nil // already signed in
		}
		// Only an AuthorizationError carrying a sign-in URL is recoverable.
		var aErr api.AuthorizationError
		if !errors.As(err, &aErr) || aErr.SigninURL == "" {
			return "", err
		}
		yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
		if err != nil || !yes {
			return "", fmt.Errorf("%s requires sign in", selected)
		}
		fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
		// Auto-open browser (best effort, fail silently)
		switch runtime.GOOS {
		case "darwin":
			_ = exec.Command("open", aErr.SigninURL).Start()
		case "linux":
			_ = exec.Command("xdg-open", aErr.SigninURL).Start()
		case "windows":
			_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
		}
		// Animate a spinner at 200ms per frame; ANSI \033[90m renders the
		// line in gray, \r\033[K redraws it in place.
		spinnerFrames := []string{"|", "/", "-", "\\"}
		frame := 0
		fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				// Clear the spinner line before giving up.
				fmt.Fprintf(os.Stderr, "\r\033[K")
				return "", ctx.Err()
			case <-ticker.C:
				frame++
				fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
				// poll every 10th frame (~2 seconds)
				if frame%10 == 0 {
					u, err := client.Whoami(ctx)
					if err == nil && u != nil && u.Name != "" {
						// Clear spinner and the prompt line above, then report success.
						fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
						return selected, nil
					}
				}
			}
		}
	}
	return selected, nil
}
// SelectModel prompts the user to select a model to run using the default
// terminal prompt selector.
func SelectModel(ctx context.Context) (string, error) {
	return SelectModelWithSelector(ctx, defaultSingleSelector)
}
// defaultSingleSelector adapts the terminal selectPrompt to the
// SingleSelector signature by converting ModelItem values to selectItem.
func defaultSingleSelector(title string, items []ModelItem) (string, error) {
	converted := make([]selectItem, 0, len(items))
	for _, it := range items {
		converted = append(converted, selectItem(it))
	}
	return selectPrompt(title, converted)
}
// defaultMultiSelector adapts the terminal multiSelectPrompt to the
// MultiSelector signature by converting ModelItem values to selectItem.
func defaultMultiSelector(title string, items []ModelItem, preChecked []string) ([]string, error) {
	converted := make([]selectItem, 0, len(items))
	for _, it := range items {
		converted = append(converted, selectItem(it))
	}
	return multiSelectPrompt(title, converted, preChecked)
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
@@ -96,8 +349,8 @@ func selectIntegration() (string, error) {
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
@@ -133,7 +386,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
var selected []string
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
@@ -142,7 +395,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
model, err := single(prompt, items)
if err != nil {
return nil, err
}
@@ -227,12 +480,17 @@ func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
modelItems, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
if len(modelItems) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
items := make([]selectItem, len(modelItems))
for i, mi := range modelItems {
items[i] = selectItem(mi)
}
return items, existingModels, cloudModels, client, nil
}
@@ -303,6 +561,11 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
}
}
// selectModels lets the user select models for an integration using the
// default terminal prompt selectors.
func selectModels(ctx context.Context, name, current string) ([]string, error) {
	return selectModelsWithSelectors(ctx, name, current, defaultSingleSelector, defaultMultiSelector)
}
func runIntegration(name, modelName string, args []string) error {
r, ok := integrations[name]
if !ok {
@@ -335,15 +598,110 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
return saveAliases(name, aliases)
}
// LaunchIntegration launches the named integration using its saved
// configuration. It returns an error for unknown integrations and for
// integrations that have not yet been configured.
func LaunchIntegration(name string) error {
	r, ok := integrations[name]
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}
	// Prefer the saved configuration; the first model is the default.
	// (Local renamed from "config" to avoid shadowing the package's config type.)
	cfg, err := loadIntegration(name)
	if err == nil && len(cfg.Models) > 0 {
		return runIntegration(name, cfg.Models[0], nil)
	}
	// Nothing saved - point the user at the setup flow.
	return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
}
// LaunchIntegrationWithModel launches the named integration with the
// specified model, bypassing any saved configuration.
func LaunchIntegrationWithModel(name, modelName string) error {
	return runIntegration(name, modelName, nil)
}
// SaveIntegrationModel records modelName as the default (first) model for an
// integration, preserving any previously configured models after it.
func SaveIntegrationModel(name, modelName string) error {
	var models []string
	if existing, err := loadIntegration(name); err == nil {
		models = existing.Models
	}
	// Drop the first prior occurrence so the model is not listed twice.
	if i := slices.Index(models, modelName); i >= 0 {
		models = slices.Delete(models, i, i+1)
	}
	// The new default goes to the front of the list.
	return saveIntegration(name, append([]string{modelName}, models...))
}
// ConfigureIntegrationWithSelectors allows the user to select/change the
// model for an integration using custom selectors.
//
// User cancellation of the selection (errCancelled) is treated as a no-op
// success. Integrations that implement Editor have their config files
// modified (after an explicit confirmation listing the affected paths); the
// chosen models are then persisted and a summary printed to stderr.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
	r, ok := integrations[name]
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}
	models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
	if errors.Is(err, errCancelled) {
		// Cancelling the picker is not an error.
		return nil
	}
	if err != nil {
		return err
	}
	// Editors rewrite the integration's own config files; warn and confirm
	// before touching them.
	if editor, isEditor := r.(Editor); isEditor {
		paths := editor.Paths()
		if len(paths) > 0 {
			fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
			for _, p := range paths {
				fmt.Fprintf(os.Stderr, " %s\n", p)
			}
			fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
			// Declining the confirmation aborts silently.
			if ok, _ := confirmPrompt("Proceed?"); !ok {
				return nil
			}
		}
		if err := editor.Edit(models); err != nil {
			return fmt.Errorf("setup failed: %w", err)
		}
	}
	if err := saveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}
	if len(models) == 1 {
		fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
	} else {
		fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
	}
	return nil
}
// ConfigureIntegration allows the user to select/change the model for an
// integration using the default terminal prompt selectors.
func ConfigureIntegration(ctx context.Context, name string) error {
	return ConfigureIntegrationWithSelectors(ctx, name, defaultSingleSelector, defaultMultiSelector)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// The runTUI callback is called when no arguments are provided (alias for main TUI).
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
@@ -362,6 +720,12 @@ Examples:
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// No args - run the main TUI (same as 'ollama')
if len(args) == 0 && modelFlag == "" && !configFlag {
runTUI(cmd)
return nil
}
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
@@ -582,7 +946,7 @@ type modelInfo struct {
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
@@ -602,7 +966,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
item := ModelItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
@@ -651,7 +1015,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
@@ -686,6 +1050,56 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
return resp.RemoteModel != ""
}
// GetModelItems returns the model list offered by the TUI, plus a set of the
// names that are actually installed. The list merges locally available
// models with recommended models that aren't installed, ordered: last-run
// model first, then installed models, then recommendations, each group
// alphabetized case-insensitively. On any client error both results are nil.
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, nil
	}
	resp, err := client.List(ctx)
	if err != nil {
		return nil, nil
	}
	var installed []modelInfo
	for _, m := range resp.Models {
		installed = append(installed, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}
	last := LastModel()
	var preChecked []string
	if last != "" {
		preChecked = []string{last}
	}
	items, _, existingModels, _ := buildModelList(installed, preChecked, last)
	// Rank: 0 = last-run model, 1 = installed, 2 = recommendation.
	rank := func(it ModelItem) int {
		switch {
		case it.Name == last:
			return 0
		case existingModels[it.Name]:
			return 1
		default:
			return 2
		}
	}
	slices.SortStableFunc(items, func(a, b ModelItem) int {
		if ra, rb := rank(a), rank(b); ra != rb {
			return ra - rb
		}
		return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
	})
	return items, existingModels
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()

View File

@@ -94,8 +94,10 @@ func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
// Mock TUI function (not called in these tests)
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck)
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
@@ -128,6 +130,75 @@ func TestLaunchCmd(t *testing.T) {
})
}
// TestLaunchCmd_TUICallback verifies that the TUI callback fires only when
// 'ollama launch' is invoked with no arguments or flags.
func TestLaunchCmd_TUICallback(t *testing.T) {
	mockCheck := func(cmd *cobra.Command, args []string) error {
		return nil
	}
	// runWith reports whether the TUI callback fired for the given CLI args.
	runWith := func(args []string) bool {
		fired := false
		cmd := LaunchCmd(mockCheck, func(cmd *cobra.Command) { fired = true })
		cmd.SetArgs(args)
		// Execute may fail (e.g. unconfigured integration); only whether the
		// callback fired matters here.
		_ = cmd.Execute()
		return fired
	}
	t.Run("no args calls TUI", func(t *testing.T) {
		if !runWith([]string{}) {
			t.Error("TUI callback should be called when no args provided")
		}
	})
	t.Run("integration arg bypasses TUI", func(t *testing.T) {
		if runWith([]string{"claude"}) {
			t.Error("TUI callback should NOT be called when integration arg provided")
		}
	})
	t.Run("--model flag bypasses TUI", func(t *testing.T) {
		if runWith([]string{"--model", "test-model"}) {
			t.Error("TUI callback should NOT be called when --model flag provided")
		}
	})
	t.Run("--config flag bypasses TUI", func(t *testing.T) {
		if runWith([]string{"--config"}) {
			t.Error("TUI callback should NOT be called when --config flag provided")
		}
	})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
@@ -168,7 +239,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
@@ -314,7 +385,7 @@ func TestIsCloudModel(t *testing.T) {
})
}
func names(items []selectItem) []string {
func names(items []ModelItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)

507
cmd/tui/selector.go Normal file
View File

@@ -0,0 +1,507 @@
package tui
import (
"errors"
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Lipgloss styles shared by the selector views. ANSI color "241" is a mid
// gray used for secondary text; "252" is a light gray for typed filter input.
var (
	// Title line shown above the list.
	selectorTitleStyle = lipgloss.NewStyle().
		Bold(true)
	// Unselected list rows (indented to align with the selected-row marker).
	selectorItemStyle = lipgloss.NewStyle().
		PaddingLeft(4)
	// The row currently under the cursor.
	selectorSelectedItemStyle = lipgloss.NewStyle().
		PaddingLeft(2).
		Bold(true)
	// Secondary description text next to an item name.
	selectorDescStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241"))
	// "Type to filter..." placeholder.
	selectorFilterStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241")).
		Italic(true)
	// The user's typed filter text.
	selectorInputStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("252"))
	// Unchecked checkbox in multi-select mode.
	selectorCheckboxStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241"))
	// Checked checkbox in multi-select mode.
	selectorCheckboxCheckedStyle = lipgloss.NewStyle().
		Bold(true)
	// "(default)"-style tag next to an item.
	selectorDefaultTagStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241")).
		Italic(true)
	// Key-binding help line at the bottom.
	selectorHelpStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241"))
	// "N more..." overflow indicator for long lists.
	selectorMoreStyle = lipgloss.NewStyle().
		PaddingLeft(4).
		Foreground(lipgloss.Color("241")).
		Italic(true)
)
// maxSelectorItems is the number of list rows visible at once; longer lists
// scroll within this window.
const maxSelectorItems = 10

// ErrCancelled is returned when the user cancels the selection.
var ErrCancelled = errors.New("cancelled")

// SelectItem represents an item that can be selected, with an optional
// secondary description shown alongside the name.
type SelectItem struct {
	Name        string
	Description string
}
// selectorModel is the bubbletea model for single selection.
type selectorModel struct {
	title        string       // prompt shown above the list
	items        []SelectItem // full, unfiltered item list
	filter       string       // current type-to-filter text
	cursor       int          // index of the highlighted row within the filtered list
	scrollOffset int          // index of the first visible row (window of maxSelectorItems)
	selected     string       // name of the confirmed item; set on Enter
	cancelled    bool         // set when the user aborts with Esc/Ctrl+C
}
// filteredItems returns the items whose names contain the current filter
// text, matched case-insensitively. With an empty filter the full item
// slice is returned unchanged.
func (m selectorModel) filteredItems() []SelectItem {
	if m.filter == "" {
		return m.items
	}
	needle := strings.ToLower(m.filter)
	var matches []SelectItem
	for _, it := range m.items {
		if strings.Contains(strings.ToLower(it.Name), needle) {
			matches = append(matches, it)
		}
	}
	return matches
}
// Init implements tea.Model; the selector needs no initial command.
func (m selectorModel) Init() tea.Cmd {
	return nil
}
// Update implements tea.Model. It handles navigation (arrows/page keys),
// incremental filtering (typed runes, backspace), confirmation (enter),
// and cancellation (esc/ctrl-c) for the single-select list.
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		filtered := m.filteredItems()
		switch msg.Type {
		case tea.KeyCtrlC, tea.KeyEsc:
			m.cancelled = true
			return m, tea.Quit
		case tea.KeyEnter:
			if len(filtered) > 0 && m.cursor < len(filtered) {
				m.selected = filtered[m.cursor].Name
			}
			return m, tea.Quit
		case tea.KeyUp:
			if m.cursor > 0 {
				m.cursor--
				// Scroll up if the cursor moved above the visible window.
				if m.cursor < m.scrollOffset {
					m.scrollOffset = m.cursor
				}
			}
		case tea.KeyDown:
			if m.cursor < len(filtered)-1 {
				m.cursor++
				// Scroll down if the cursor moved below the visible window.
				if m.cursor >= m.scrollOffset+maxSelectorItems {
					m.scrollOffset = m.cursor - maxSelectorItems + 1
				}
			}
		case tea.KeyPgUp:
			m.cursor -= maxSelectorItems
			if m.cursor < 0 {
				m.cursor = 0
			}
			m.scrollOffset -= maxSelectorItems
			if m.scrollOffset < 0 {
				m.scrollOffset = 0
			}
		case tea.KeyPgDown:
			m.cursor += maxSelectorItems
			if m.cursor >= len(filtered) {
				m.cursor = len(filtered) - 1
			}
			if m.cursor >= m.scrollOffset+maxSelectorItems {
				m.scrollOffset = m.cursor - maxSelectorItems + 1
			}
		case tea.KeyBackspace:
			if len(m.filter) > 0 {
				// Drop the last rune, not the last byte: the filter is
				// built from msg.Runes and may contain multi-byte UTF-8,
				// which byte-wise truncation would corrupt.
				r := []rune(m.filter)
				m.filter = string(r[:len(r)-1])
				m.cursor = 0
				m.scrollOffset = 0
			}
		case tea.KeyRunes:
			m.filter += string(msg.Runes)
			m.cursor = 0
			m.scrollOffset = 0
		}
	}
	return m, nil
}
// View renders the single-select list: a title with the live filter text, a
// window of at most maxSelectorItems rows (with an overflow count), and a
// key-binding help line.
func (m selectorModel) View() string {
	// Clear screen when exiting so the selector leaves no residue.
	if m.cancelled || m.selected != "" {
		return ""
	}
	var s strings.Builder
	// Title with filter (placeholder text until the user types).
	s.WriteString(selectorTitleStyle.Render(m.title))
	s.WriteString(" ")
	if m.filter == "" {
		s.WriteString(selectorFilterStyle.Render("Type to filter..."))
	} else {
		s.WriteString(selectorInputStyle.Render(m.filter))
	}
	s.WriteString("\n\n")
	filtered := m.filteredItems()
	if len(filtered) == 0 {
		s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
		s.WriteString("\n")
	} else {
		// Render the visible window [scrollOffset, scrollOffset+displayCount).
		displayCount := min(len(filtered), maxSelectorItems)
		for i := range displayCount {
			idx := m.scrollOffset + i
			if idx >= len(filtered) {
				break
			}
			item := filtered[idx]
			if idx == m.cursor {
				s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
			} else {
				s.WriteString(selectorItemStyle.Render(item.Name))
			}
			if item.Description != "" {
				s.WriteString(" ")
				s.WriteString(selectorDescStyle.Render("- " + item.Description))
			}
			s.WriteString("\n")
		}
		// Indicate how many rows remain hidden below the window.
		if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
			s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
			s.WriteString("\n")
		}
	}
	s.WriteString("\n")
	s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
	return s.String()
}
// SelectSingle prompts the user to select a single item from a list.
// It returns the selected item's name, or ErrCancelled if the user aborted
// with esc/ctrl-c. An empty return with a nil error means the user confirmed
// with no match under the cursor.
func SelectSingle(title string, items []SelectItem) (string, error) {
	if len(items) == 0 {
		// Static message: errors.New, not fmt.Errorf (no formatting needed).
		return "", errors.New("no items to select from")
	}
	m := selectorModel{
		title: title,
		items: items,
	}
	p := tea.NewProgram(m)
	finalModel, err := p.Run()
	if err != nil {
		return "", fmt.Errorf("error running selector: %w", err)
	}
	fm := finalModel.(selectorModel)
	if fm.cancelled {
		return "", ErrCancelled
	}
	return fm.selected, nil
}
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
	title        string
	items        []SelectItem
	itemIndex    map[string]int
	filter       string
	cursor       int
	scrollOffset int
	checked      map[int]bool
	checkOrder   []int
	cancelled    bool
	confirmed    bool
}

// newMultiSelectorModel builds a multi-select model over items, pre-checking
// any names from preChecked that exist in items (preserving preChecked order).
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
	index := make(map[string]int, len(items))
	for i, item := range items {
		index[item.Name] = i
	}
	m := multiSelectorModel{
		title:     title,
		items:     items,
		itemIndex: index,
		checked:   make(map[int]bool),
	}
	for _, name := range preChecked {
		idx, ok := index[name]
		if !ok {
			continue
		}
		m.checked[idx] = true
		m.checkOrder = append(m.checkOrder, idx)
	}
	return m
}

// filteredItems returns the items whose names contain the current filter,
// compared case-insensitively. An empty filter matches everything.
func (m multiSelectorModel) filteredItems() []SelectItem {
	if m.filter == "" {
		return m.items
	}
	needle := strings.ToLower(m.filter)
	var matched []SelectItem
	for _, it := range m.items {
		if strings.Contains(strings.ToLower(it.Name), needle) {
			matched = append(matched, it)
		}
	}
	return matched
}
// toggleItem flips the checked state of the item under the cursor within the
// current filtered view. Checked state is keyed by the item's index in the
// full (unfiltered) list; checkOrder records the order items were checked.
func (m *multiSelectorModel) toggleItem() {
	filtered := m.filteredItems()
	if len(filtered) == 0 || m.cursor >= len(filtered) {
		return
	}
	origIdx := m.itemIndex[filtered[m.cursor].Name]
	if !m.checked[origIdx] {
		// Check: record it and remember it was checked last.
		m.checked[origIdx] = true
		m.checkOrder = append(m.checkOrder, origIdx)
		return
	}
	// Uncheck: drop it from both the set and the order list.
	delete(m.checked, origIdx)
	for i, idx := range m.checkOrder {
		if idx == origIdx {
			m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
			break
		}
	}
}

// selectedCount reports how many items are currently checked.
func (m multiSelectorModel) selectedCount() int {
	return len(m.checkOrder)
}
// Init implements tea.Model; the multi selector needs no initial command.
func (m multiSelectorModel) Init() tea.Cmd {
	return nil
}
// Update implements tea.Model for the multi-select list: space toggles the
// item under the cursor, enter confirms (only once at least one item is
// checked), typed runes filter the list, and esc/ctrl-c cancels.
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		filtered := m.filteredItems()
		switch msg.Type {
		case tea.KeyCtrlC, tea.KeyEsc:
			m.cancelled = true
			return m, tea.Quit
		case tea.KeyEnter:
			// Enter confirms only if at least one item is selected.
			if len(m.checkOrder) > 0 {
				m.confirmed = true
				return m, tea.Quit
			}
		case tea.KeySpace:
			// Space always toggles selection.
			m.toggleItem()
		case tea.KeyUp:
			if m.cursor > 0 {
				m.cursor--
				// Scroll up if the cursor moved above the visible window.
				if m.cursor < m.scrollOffset {
					m.scrollOffset = m.cursor
				}
			}
		case tea.KeyDown:
			if m.cursor < len(filtered)-1 {
				m.cursor++
				// Scroll down if the cursor moved below the visible window.
				if m.cursor >= m.scrollOffset+maxSelectorItems {
					m.scrollOffset = m.cursor - maxSelectorItems + 1
				}
			}
		case tea.KeyPgUp:
			m.cursor -= maxSelectorItems
			if m.cursor < 0 {
				m.cursor = 0
			}
			m.scrollOffset -= maxSelectorItems
			if m.scrollOffset < 0 {
				m.scrollOffset = 0
			}
		case tea.KeyPgDown:
			m.cursor += maxSelectorItems
			if m.cursor >= len(filtered) {
				m.cursor = len(filtered) - 1
			}
			if m.cursor >= m.scrollOffset+maxSelectorItems {
				m.scrollOffset = m.cursor - maxSelectorItems + 1
			}
		case tea.KeyBackspace:
			if len(m.filter) > 0 {
				// Drop the last rune, not the last byte: the filter is
				// built from msg.Runes and may contain multi-byte UTF-8,
				// which byte-wise truncation would corrupt.
				r := []rune(m.filter)
				m.filter = string(r[:len(r)-1])
				m.cursor = 0
				m.scrollOffset = 0
			}
		case tea.KeyRunes:
			m.filter += string(msg.Runes)
			m.cursor = 0
			m.scrollOffset = 0
		}
	}
	return m, nil
}
// View renders the multi-select list: title + live filter, a window of rows
// with checkboxes (the earliest-checked item is tagged "(default)"), a status
// line with the current count, and a key-binding help line.
func (m multiSelectorModel) View() string {
	// Clear screen when exiting
	if m.cancelled || m.confirmed {
		return ""
	}
	var s strings.Builder
	// Title with filter
	s.WriteString(selectorTitleStyle.Render(m.title))
	s.WriteString(" ")
	if m.filter == "" {
		s.WriteString(selectorFilterStyle.Render("Type to filter..."))
	} else {
		s.WriteString(selectorInputStyle.Render(m.filter))
	}
	s.WriteString("\n\n")
	filtered := m.filteredItems()
	if len(filtered) == 0 {
		s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
		s.WriteString("\n")
	} else {
		// Render the visible window [scrollOffset, scrollOffset+displayCount).
		displayCount := min(len(filtered), maxSelectorItems)
		for i := range displayCount {
			idx := m.scrollOffset + i
			if idx >= len(filtered) {
				break
			}
			item := filtered[idx]
			// Checked state is keyed by index into the full (unfiltered) list.
			origIdx := m.itemIndex[item.Name]
			// Checkbox
			var checkbox string
			if m.checked[origIdx] {
				checkbox = selectorCheckboxCheckedStyle.Render("[x]")
			} else {
				checkbox = selectorCheckboxStyle.Render("[ ]")
			}
			// Cursor and name
			var line string
			if idx == m.cursor {
				line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
			} else {
				line = " " + checkbox + " " + item.Name
			}
			// Default tag: the first-checked item is treated as the default.
			if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
				line += " " + selectorDefaultTagStyle.Render("(default)")
			}
			s.WriteString(line)
			s.WriteString("\n")
		}
		// Indicate how many rows remain hidden below the window.
		if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
			s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
			s.WriteString("\n")
		}
	}
	s.WriteString("\n")
	// Status line: prompt until something is checked, then a running count.
	count := m.selectedCount()
	if count == 0 {
		s.WriteString(selectorDescStyle.Render(" Select at least one model."))
	} else {
		s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
	}
	s.WriteString("\n\n")
	s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
	return s.String()
}
// SelectMultiple prompts the user to select multiple items from a list,
// returning the chosen names in the order they were checked. Items named in
// preChecked start out checked. It returns ErrCancelled if the user aborted
// or quit without confirming.
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
	if len(items) == 0 {
		// Static message: errors.New, not fmt.Errorf (no formatting needed).
		return nil, errors.New("no items to select from")
	}
	m := newMultiSelectorModel(title, items, preChecked)
	p := tea.NewProgram(m)
	finalModel, err := p.Run()
	if err != nil {
		return nil, fmt.Errorf("error running selector: %w", err)
	}
	fm := finalModel.(multiSelectorModel)
	// Anything other than an explicit confirm counts as cancellation.
	if fm.cancelled || !fm.confirmed {
		return nil, ErrCancelled
	}
	result := make([]string, 0, len(fm.checkOrder))
	for _, idx := range fm.checkOrder {
		result = append(result, fm.items[idx].Name)
	}
	return result, nil
}

731
cmd/tui/tui.go Normal file
View File

@@ -0,0 +1,731 @@
package tui
import (
"context"
"errors"
"fmt"
"os/exec"
"runtime"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/version"
)
// Styles for the main launch menu.
var (
	// Menu title line.
	titleStyle = lipgloss.NewStyle().
		Bold(true).
		MarginBottom(1)
	// Version string rendered next to the title.
	versionStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("245"))
	// Normal menu rows.
	itemStyle = lipgloss.NewStyle().
		PaddingLeft(2)
	// Row under the cursor.
	selectedStyle = lipgloss.NewStyle().
		PaddingLeft(2).
		Bold(true)
	// Rows for integrations that are not installed.
	greyedStyle = lipgloss.NewStyle().
		PaddingLeft(2).
		Foreground(lipgloss.Color("241"))
	// A not-installed integration row under the cursor.
	greyedSelectedStyle = lipgloss.NewStyle().
		PaddingLeft(2).
		Foreground(lipgloss.Color("243"))
	// Per-item description line below each row.
	descStyle = lipgloss.NewStyle().
		PaddingLeft(4).
		Foreground(lipgloss.Color("241"))
	// Configured model name shown in parentheses after a row's title.
	modelStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("245"))
	// "(not installed)" marker.
	notInstalledStyle = lipgloss.NewStyle().
		Foreground(lipgloss.Color("241")).
		Italic(true)
)
// menuItem is a single row in the launch menu.
type menuItem struct {
	title       string
	description string
	integration string // integration name for loading model config, empty if not an integration
	isRunModel  bool   // true for the "Run a model" option
	isOthers    bool   // true for the "Others..." toggle item
}

// mainMenuItems are the entries always visible at the top of the menu.
var mainMenuItems = []menuItem{
	{
		title:       "Run a model",
		description: "Start an interactive chat with a local model",
		isRunModel:  true,
	},
	{
		title:       "Launch Claude Code",
		description: "Open Claude Code AI assistant",
		integration: "claude",
	},
	{
		title:       "Launch Codex",
		description: "Open Codex CLI",
		integration: "codex",
	},
	{
		title:       "Launch Open Claw",
		description: "Open the Open Claw integration",
		integration: "openclaw",
	},
}

// othersMenuItem is the collapsed toggle row that reveals extra integrations.
var othersMenuItem = menuItem{
	title:       "Others...",
	description: "Show additional integrations",
	isOthers:    true,
}
// getOtherIntegrations returns the secondary integrations listed under the
// expandable "Others..." menu section.
//
// NOTE(review): an earlier comment claimed Codex is filtered out here when
// not installed; no filtering happens in this function (Codex lives in
// mainMenuItems), so the comment has been corrected to match the code.
func getOtherIntegrations() []menuItem {
	return []menuItem{
		{
			title:       "Launch Droid",
			description: "Open Droid integration",
			integration: "droid",
		},
		{
			title:       "Launch Open Code",
			description: "Open Open Code integration",
			integration: "opencode",
		},
	}
}
// model is the bubbletea model for the main launch menu. It owns three UI
// states: the menu itself, an overlaid model-picker modal, and a sign-in
// dialog for cloud models.
type model struct {
	items           []menuItem
	cursor          int
	quitting        bool
	selected        bool            // true if user made a selection (enter/space)
	changeModel     bool            // true if user pressed right arrow to change model
	showOthers      bool            // true if "Others..." is expanded
	availableModels map[string]bool // cache of available model names
	err             error

	// Modal state
	showingModal  bool          // true when model picker modal is visible
	modalSelector selectorModel // the selector model for the modal
	modalItems    []SelectItem  // cached items for the modal

	// Sign-in dialog state
	showingSignIn   bool   // true when sign-in dialog is visible
	signInURL       string // URL for sign-in
	signInModel     string // model that requires sign-in
	signInSpinner   int    // spinner frame index
	signInFromModal bool   // true if sign-in was triggered from modal (not main menu)
}

// signInTickMsg is sent periodically to animate the sign-in spinner.
type signInTickMsg struct{}

// signInCheckMsg reports the result of a sign-in status poll.
type signInCheckMsg struct {
	signedIn bool
	userName string
}
// modelExists reports whether name refers to a model in the cached list,
// either exactly or as an untagged prefix (e.g. "llama2" matches
// "llama2:latest"). Returns false when the cache is unset or name is empty.
func (m *model) modelExists(name string) bool {
	if name == "" || m.availableModels == nil {
		return false
	}
	if m.availableModels[name] {
		return true
	}
	prefix := name + ":"
	for candidate := range m.availableModels {
		if strings.HasPrefix(candidate, prefix) {
			return true
		}
	}
	return false
}
// buildModalItems creates the list of models for the modal selector. The
// error from config.GetModelItems is ignored (best effort): on failure the
// list is empty and the modal renders "(no matches)".
func (m *model) buildModalItems() []SelectItem {
	modelItems, _ := config.GetModelItems(context.Background())
	var items []SelectItem
	for _, item := range modelItems {
		items = append(items, SelectItem{Name: item.Name, Description: item.Description})
	}
	return items
}

// openModelModal opens the model picker modal over the main menu.
func (m *model) openModelModal() {
	m.modalItems = m.buildModalItems()
	m.modalSelector = selectorModel{
		title: "Select model:",
		items: m.modalItems,
	}
	m.showingModal = true
}
// isCloudModel reports whether the model name refers to a cloud-hosted
// model, i.e. it carries the ":cloud" tag suffix.
func isCloudModel(name string) bool {
	const cloudSuffix = ":cloud"
	return len(name) >= len(cloudSuffix) && name[len(name)-len(cloudSuffix):] == cloudSuffix
}
// checkCloudSignIn checks whether using the given cloud model requires the
// user to sign in first. It returns a command that starts the sign-in flow,
// or nil when no sign-in is needed: name is empty or not a cloud model, the
// client cannot be constructed, the user is already signed in, or no sign-in
// URL is available. fromModal records whether the request originated from
// the model picker modal (affects where the UI returns on cancel).
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
	if modelName == "" || !isCloudModel(modelName) {
		return nil
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil
	}
	user, err := client.Whoami(context.Background())
	if err == nil && user != nil && user.Name != "" {
		return nil // Already signed in
	}
	// An AuthorizationError from Whoami carries the URL to visit to sign in.
	var aErr api.AuthorizationError
	if errors.As(err, &aErr) && aErr.SigninURL != "" {
		return m.startSignIn(modelName, aErr.SigninURL, fromModal)
	}
	return nil
}
// startSignIn initiates the sign-in flow for a cloud model: it switches the
// UI to the sign-in dialog, opens the browser at signInURL (best effort; the
// URL is also shown in the dialog so the user can open it manually), and
// returns the first spinner tick command.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
	m.showingModal = false
	m.showingSignIn = true
	m.signInURL = signInURL
	m.signInModel = modelName
	m.signInSpinner = 0
	m.signInFromModal = fromModal
	// Open browser (best effort); errors are deliberately ignored since the
	// URL is displayed for manual navigation.
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", signInURL).Start()
	case "linux":
		_ = exec.Command("xdg-open", signInURL).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", signInURL).Start()
	}
	// Start the spinner tick; Update re-arms it and polls sign-in status.
	return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
		return signInTickMsg{}
	})
}
// checkSignIn polls whether the user has completed sign-in and reports the
// outcome as a signInCheckMsg (zero value means not signed in).
func checkSignIn() tea.Msg {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return signInCheckMsg{}
	}
	user, err := client.Whoami(context.Background())
	if err != nil || user == nil || user.Name == "" {
		return signInCheckMsg{}
	}
	return signInCheckMsg{signedIn: true, userName: user.Name}
}
// loadAvailableModels fetches the installed model list from the server and
// caches the names. Failures leave the cache empty (best effort).
func (m *model) loadAvailableModels() {
	m.availableModels = map[string]bool{}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return
	}
	resp, err := client.List(context.Background())
	if err != nil {
		return
	}
	for _, info := range resp.Models {
		m.availableModels[info.Name] = true
	}
}
// buildItems assembles the menu rows: the main entries, then either the
// collapsed "Others..." toggle or, when expanded, a "Hide others..." toggle
// followed by the extra integrations.
func (m *model) buildItems() {
	extra := getOtherIntegrations()
	m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(extra))
	m.items = append(m.items, mainMenuItems...)
	if !m.showOthers {
		m.items = append(m.items, othersMenuItem)
		return
	}
	// Expanded: the toggle reads "Hide others..." instead.
	m.items = append(m.items, menuItem{
		title:       "Hide others...",
		description: "Hide additional integrations",
		isOthers:    true,
	})
	m.items = append(m.items, extra...)
}
// isOthersIntegration reports whether the named integration lives under the
// collapsible "Others..." section of the menu.
func isOthersIntegration(name string) bool {
	return name == "droid" || name == "opencode"
}
// initialModel builds the menu model: it caches the installed model list,
// expands the "Others" section when the remembered selection lives there,
// and positions the cursor on the user's last selection.
func initialModel() model {
	m := model{
		cursor: 0,
	}
	m.loadAvailableModels()
	// Check last selection to determine if we need to expand "Others" so
	// the cursor can land on that row below.
	lastSelection := config.LastSelection()
	if isOthersIntegration(lastSelection) {
		m.showOthers = true
	}
	m.buildItems()
	// Position cursor on last selection ("run" maps to the Run-a-model row).
	if lastSelection != "" {
		for i, item := range m.items {
			if lastSelection == "run" && item.isRunModel {
				m.cursor = i
				break
			} else if item.integration == lastSelection {
				m.cursor = i
				break
			}
		}
	}
	return m
}

// Init implements tea.Model; the menu needs no initial command.
func (m model) Init() tea.Cmd {
	return nil
}
// Update implements tea.Model and routes input to one of three UI states, in
// priority order: the sign-in dialog, the model-picker modal, then the main
// menu.
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	// Handle sign-in dialog
	if m.showingSignIn {
		switch msg := msg.(type) {
		case tea.KeyMsg:
			switch msg.Type {
			case tea.KeyCtrlC, tea.KeyEsc:
				// Cancel sign-in and go back to wherever it started from.
				m.showingSignIn = false
				if m.signInFromModal {
					m.showingModal = true
				}
				// If from main menu, just return to main menu (default state)
				return m, nil
			}
		case signInTickMsg:
			m.signInSpinner++
			// Check sign-in status every 5th tick (~1 second)
			if m.signInSpinner%5 == 0 {
				return m, tea.Batch(
					tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
						return signInTickMsg{}
					}),
					checkSignIn,
				)
			}
			return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
				return signInTickMsg{}
			})
		case signInCheckMsg:
			if msg.signedIn {
				// Sign-in complete - proceed with the pending selection.
				if m.signInFromModal {
					// Came from modal - set changeModel
					m.modalSelector.selected = m.signInModel
					m.changeModel = true
				} else {
					// Came from main menu - just select
					m.selected = true
				}
				m.quitting = true
				return m, tea.Quit
			}
		}
		return m, nil
	}
	// Handle modal input if modal is showing
	if m.showingModal {
		switch msg := msg.(type) {
		case tea.KeyMsg:
			switch msg.Type {
			case tea.KeyCtrlC, tea.KeyEsc:
				// Close modal without selection
				m.showingModal = false
				return m, nil
			case tea.KeyEnter:
				filtered := m.modalSelector.filteredItems()
				if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
					m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
				}
				if m.modalSelector.selected != "" {
					// Cloud models may require sign-in before use.
					if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
						return m, cmd
					}
					// Selection made - exit with changeModel
					m.changeModel = true
					m.quitting = true
					return m, tea.Quit
				}
				return m, nil
			case tea.KeyUp:
				if m.modalSelector.cursor > 0 {
					m.modalSelector.cursor--
					if m.modalSelector.cursor < m.modalSelector.scrollOffset {
						m.modalSelector.scrollOffset = m.modalSelector.cursor
					}
				}
			case tea.KeyDown:
				filtered := m.modalSelector.filteredItems()
				if m.modalSelector.cursor < len(filtered)-1 {
					m.modalSelector.cursor++
					if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
						m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
					}
				}
			case tea.KeyPgUp:
				// Page up never needs the filtered list; both cursor and
				// offset clamp at zero. (Previously an unused variable was
				// declared here and suppressed with a blank assignment.)
				m.modalSelector.cursor -= maxSelectorItems
				if m.modalSelector.cursor < 0 {
					m.modalSelector.cursor = 0
				}
				m.modalSelector.scrollOffset -= maxSelectorItems
				if m.modalSelector.scrollOffset < 0 {
					m.modalSelector.scrollOffset = 0
				}
			case tea.KeyPgDown:
				filtered := m.modalSelector.filteredItems()
				m.modalSelector.cursor += maxSelectorItems
				if m.modalSelector.cursor >= len(filtered) {
					m.modalSelector.cursor = len(filtered) - 1
				}
				if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
					m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
				}
			case tea.KeyBackspace:
				if len(m.modalSelector.filter) > 0 {
					// Drop the last rune, not the last byte, so multi-byte
					// input (added via KeyRunes) is removed cleanly.
					r := []rune(m.modalSelector.filter)
					m.modalSelector.filter = string(r[:len(r)-1])
					m.modalSelector.cursor = 0
					m.modalSelector.scrollOffset = 0
				}
			case tea.KeyRunes:
				m.modalSelector.filter += string(msg.Runes)
				m.modalSelector.cursor = 0
				m.modalSelector.scrollOffset = 0
			}
		}
		return m, nil
	}
	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch msg.String() {
		case "ctrl+c", "q", "esc":
			m.quitting = true
			return m, tea.Quit
		case "up", "k":
			if m.cursor > 0 {
				m.cursor--
			}
		case "down", "j":
			if m.cursor < len(m.items)-1 {
				m.cursor++
			}
		case "enter", " ":
			item := m.items[m.cursor]
			// Handle "Others..." toggle
			if item.isOthers {
				m.showOthers = !m.showOthers
				m.buildItems()
				// Keep cursor in range after the list changed size.
				if m.cursor >= len(m.items) {
					m.cursor = len(m.items) - 1
				}
				return m, nil
			}
			// Don't allow selecting uninstalled integrations
			if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
				return m, nil
			}
			// Check if a cloud model is configured and needs sign-in
			var configuredModel string
			if item.isRunModel {
				configuredModel = config.LastModel()
			} else if item.integration != "" {
				configuredModel = config.IntegrationModel(item.integration)
			}
			if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
				return m, cmd
			}
			m.selected = true
			m.quitting = true
			return m, tea.Quit
		case "right", "l":
			// Right arrow opens the model picker for integrations and "Run a model".
			item := m.items[m.cursor]
			if item.integration != "" || item.isRunModel {
				// Don't allow for uninstalled integrations
				if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
					return m, nil
				}
				m.openModelModal()
			}
		}
	}
	return m, nil
}
// View renders the main menu: title with version, each row styled by cursor
// position and install state (with the configured model shown when it exists
// locally), and a key-binding help line. Active sub-views (sign-in dialog,
// model modal) take over the whole screen.
func (m model) View() string {
	if m.quitting {
		return ""
	}
	// Render sign-in dialog if showing
	if m.showingSignIn {
		return m.renderSignInDialog()
	}
	// Render modal overlay if showing - replaces main view
	if m.showingModal {
		return m.renderModal()
	}
	s := titleStyle.Render(" Ollama "+versionStyle.Render("v"+version.Version)) + "\n\n"
	for i, item := range m.items {
		cursor := " "
		style := itemStyle
		isInstalled := true
		if item.integration != "" {
			isInstalled = config.IsIntegrationInstalled(item.integration)
		}
		// Pick the row style: cursor + install state decide it.
		if m.cursor == i {
			cursor = "▸ "
			if isInstalled {
				style = selectedStyle
			} else {
				style = greyedSelectedStyle
			}
		} else if !isInstalled && item.integration != "" {
			style = greyedStyle
		}
		title := item.title
		// Annotate the title: "(not installed)", or the configured model —
		// shown only when that model is actually present locally.
		if item.integration != "" {
			if !isInstalled {
				title += " " + notInstalledStyle.Render("(not installed)")
			} else if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
				title += " " + modelStyle.Render("("+mdl+")")
			}
		} else if item.isRunModel {
			if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
				title += " " + modelStyle.Render("("+mdl+")")
			}
		}
		s += style.Render(cursor+title) + "\n"
		s += descStyle.Render(item.description) + "\n\n"
	}
	s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render("↑/↓ navigate • enter select • → change model • esc quit")
	return s
}
// renderModal renders the model picker modal inside a rounded border. The
// body mirrors selectorModel.View (title + filter, windowed item list, help
// line) but is built inline here so it can be framed as an overlay.
func (m model) renderModal() string {
	modalStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("245")).
		Padding(1, 2).
		MarginLeft(2)
	var content strings.Builder
	// Title with filter
	content.WriteString(selectorTitleStyle.Render(m.modalSelector.title))
	content.WriteString(" ")
	if m.modalSelector.filter == "" {
		content.WriteString(selectorFilterStyle.Render("Type to filter..."))
	} else {
		content.WriteString(selectorInputStyle.Render(m.modalSelector.filter))
	}
	content.WriteString("\n\n")
	filtered := m.modalSelector.filteredItems()
	if len(filtered) == 0 {
		content.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
		content.WriteString("\n")
	} else {
		// Render the visible window [scrollOffset, scrollOffset+displayCount).
		displayCount := min(len(filtered), maxSelectorItems)
		for i := range displayCount {
			idx := m.modalSelector.scrollOffset + i
			if idx >= len(filtered) {
				break
			}
			item := filtered[idx]
			if idx == m.modalSelector.cursor {
				content.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
			} else {
				content.WriteString(selectorItemStyle.Render(item.Name))
			}
			if item.Description != "" {
				content.WriteString(" ")
				content.WriteString(selectorDescStyle.Render("- " + item.Description))
			}
			content.WriteString("\n")
		}
		// Indicate how many rows remain hidden below the window.
		if remaining := len(filtered) - m.modalSelector.scrollOffset - displayCount; remaining > 0 {
			content.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
			content.WriteString("\n")
		}
	}
	content.WriteString("\n")
	content.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
	return modalStyle.Render(content.String())
}
// renderSignInDialog renders the sign-in dialog: which model needs sign-in,
// the URL to open, and an animated spinner while Update polls for completion.
func (m model) renderSignInDialog() string {
	dialogStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("245")).
		Padding(1, 2).
		MarginLeft(2)
	// Braille spinner frames; signInSpinner advances on every tick message.
	spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
	spinner := spinnerFrames[m.signInSpinner%len(spinnerFrames)]
	var content strings.Builder
	content.WriteString(selectorTitleStyle.Render("Sign in required"))
	content.WriteString("\n\n")
	content.WriteString(fmt.Sprintf("To use %s, please sign in.\n\n", selectedStyle.Render(m.signInModel)))
	content.WriteString("Navigate to:\n")
	content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Render(" " + m.signInURL))
	content.WriteString("\n\n")
	content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(
		fmt.Sprintf("%s Waiting for sign in to complete...", spinner)))
	content.WriteString("\n\n")
	content.WriteString(selectorHelpStyle.Render("esc cancel"))
	return dialogStyle.Render(content.String())
}
// Selection represents what the user selected in the launch menu.
type Selection int

const (
	SelectionNone              Selection = iota // quit without choosing
	SelectionRunModel                           // launch "Run a model"
	SelectionChangeRunModel                     // change the model used by "Run a model"
	SelectionIntegration                        // Generic integration selection
	SelectionChangeIntegration                  // Generic change model for integration
)

// Result contains the selection and any associated data.
type Result struct {
	Selection   Selection
	Integration string // integration name if applicable
	Model       string // model name if selected from modal
}
// Run starts the TUI and returns the user's selection. A Result with
// SelectionNone and a nil error means the user quit without choosing.
func Run() (Result, error) {
	m := initialModel()
	p := tea.NewProgram(m)
	finalModel, err := p.Run()
	if err != nil {
		return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
	}
	fm := finalModel.(model)
	if fm.err != nil {
		return Result{Selection: SelectionNone}, fm.err
	}
	// User quit without selecting
	if !fm.selected && !fm.changeModel {
		return Result{Selection: SelectionNone}, nil
	}
	item := fm.items[fm.cursor]
	// Handle model change request (right-arrow flow: a model was picked in
	// the modal for the row under the cursor).
	if fm.changeModel {
		if item.isRunModel {
			return Result{
				Selection: SelectionChangeRunModel,
				Model:     fm.modalSelector.selected,
			}, nil
		}
		return Result{
			Selection:   SelectionChangeIntegration,
			Integration: item.integration,
			Model:       fm.modalSelector.selected,
		}, nil
	}
	// Handle selection (enter/space flow: launch the highlighted entry).
	if item.isRunModel {
		return Result{Selection: SelectionRunModel}, nil
	}
	return Result{
		Selection:   SelectionIntegration,
		Integration: item.integration,
	}, nil
}

23
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -21,16 +21,16 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/mattn/go-runewidth v0.0.14
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -40,23 +40,34 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tkrajina/go-reflector v0.5.5 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

67
go.sum
View File

@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -148,15 +164,17 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -164,6 +182,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -184,8 +208,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -208,39 +233,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -249,6 +247,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -335,6 +335,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -4,7 +4,6 @@ import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -18,8 +17,6 @@ func Execute(args []string) error {
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)

View File

@@ -5,13 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"os"
"os/exec"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -26,7 +22,6 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -200,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if s.loadSafetensors(pending) {
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -557,90 +563,9 @@ iGPUScan:
return false
}
func subproc(args, environ []string) (*exec.Cmd, int, error) {
exe, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
}
for range 3 {
// get a random port in the ephemeral range
port := rand.Intn(65535-49152) + 49152
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
cmd.Env = slices.Concat(os.Environ(), environ)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
continue
}
return cmd, port, nil
}
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
// loadSafetensors serves a request for an experimental safetensors model.
// Image-capable models are delegated to loadImageGen; all others get a
// dedicated MLX runner subprocess. It always returns true, telling the
// caller the request was consumed (successfully or not) and must not fall
// through to the regular load path.
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
	if slices.Contains(req.model.Config.Capabilities, "image") {
		return s.loadImageGen(req)
	}
	args := []string{"--mlx-engine", "--model", req.model.ShortName}
	environ := []string{}
	cmd, port, err := subproc(args, environ)
	if err != nil {
		// Deliver the failure to the waiting request; still "handled".
		req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
		return true
	}
	// Per-request keep-alive overrides the global default.
	sessionDuration := envconfig.KeepAlive()
	if req.sessionDuration != nil {
		sessionDuration = req.sessionDuration.Duration
	}
	runner := &runnerRef{
		model:           req.model,
		modelPath:       req.model.ModelPath,
		Options:         &req.opts,
		loading:         false,
		sessionDuration: sessionDuration,
		llama: &mlxrunner.Client{
			Cmd:  cmd,
			Port: port,
		},
	}
	// Register the runner in the loaded map before arming the expiry timer,
	// so consumers of expiredCh can find it.
	s.loadedMu.Lock()
	s.loaded[req.model.ModelPath] = runner
	s.loadedMu.Unlock()
	runner.refMu.Lock()
	if sessionDuration > 0 {
		runner.expireTimer = time.AfterFunc(sessionDuration, func() {
			s.expiredCh <- runner
		})
	}
	runner.refMu.Unlock()
	req.useLoadedRunner(runner, s.finishedReqCh)
	// Poll until the runner answers a ping, checking every 20ms with a
	// 200ms per-attempt timeout.
	// NOTE(review): time.Tick leaks its ticker and this loop has no overall
	// deadline — if the subprocess never comes up, this blocks forever.
	for range time.Tick(20 * time.Millisecond) {
		if err := func() error {
			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
			defer cancel()
			return runner.llama.Ping(ctx)
		}(); err != nil {
			continue
		}
		break
	}
	return true
}
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {

View File

@@ -1,14 +1,5 @@
package tokenizer
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/ollama/ollama/types/model"
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
@@ -24,287 +15,3 @@ type Tokenizer interface {
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
// New builds a Tokenizer from the HuggingFace tokenizer.json found in root,
// augmented with special-token metadata from the sibling config files.
// The concrete implementation is chosen from hints in the JSON: WordPiece
// model type, a SentencePiece-style decoder, or byte-pair encoding otherwise.
func New(root *model.Root) (Tokenizer, error) {
	f, err := root.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()
	// Minimal view of tokenizer.json: vocab/merges, plus the raw
	// pre_tokenizer and decoder sections which are decoded further below.
	var tokenizer struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges json.RawMessage  `json:"merges"`
		} `json:"model"`
		PreTokenizer json.RawMessage `json:"pre_tokenizer"`
		Decoder      json.RawMessage `json:"decoder"`
		AddedTokens  []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
	}
	if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
		return nil, err
	}
	// Fold added tokens into the vocab and remember their IDs. Note the
	// Special flag is not consulted: every added token is typed as
	// user-defined below.
	special := make(map[int32]struct{})
	for _, token := range tokenizer.AddedTokens {
		tokenizer.Model.Vocab[token.Content] = token.ID
		special[token.ID] = struct{}{}
	}
	// Resolve BOS/EOS IDs and add-BOS/EOS behavior from the config files.
	vocab, err := specialTokens(root, tokenizer.Model.Vocab)
	if err != nil {
		return nil, err
	}
	// Build ID-indexed tables.
	// NOTE(review): assumes token IDs are dense in [0, len(vocab)); an
	// out-of-range ID would panic here — confirm upstream guarantees.
	vocab.Values = make([]string, len(tokenizer.Model.Vocab))
	vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
	vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
	for content, id := range tokenizer.Model.Vocab {
		vocab.Values[id] = content
		vocab.Scores[id] = float32(id)
		vocab.Types[id] = TOKEN_TYPE_NORMAL
		if _, ok := special[id]; ok {
			vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
		}
	}
	// Merges appear either as [["a","b"], ...] pairs or as ["a b", ...]
	// strings; normalize to the "a b" string form.
	if tokenizer.Model.Merges != nil {
		var pairs [][]string
		if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
			vocab.Merges = make([]string, len(pairs))
			for i, pair := range pairs {
				vocab.Merges[i] = pair[0] + " " + pair[1]
			}
		} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
			return nil, err
		}
	}
	// Seed the lazily-built reverse lookup with the map we already have and
	// mark the sync.Once done so it is not rebuilt later.
	vocab.valuesOnce.Do(func() {})
	vocab.values = tokenizer.Model.Vocab
	if tokenizer.Model.Type == "WordPiece" {
		return NewWordPiece(vocab, true), nil
	}
	// A Sequence decoder containing a Replace of "▁" marks a
	// SentencePiece-style tokenizer.
	if tokenizer.Decoder != nil {
		var decoder struct {
			Type     string `json:"type"`
			Decoders []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"string"`
				} `json:"pattern"`
			} `json:"decoders"`
		}
		if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
			return nil, err
		}
		if decoder.Type == "Sequence" {
			for _, d := range decoder.Decoders {
				if d.Type == "Replace" && d.Pattern.String == "▁" {
					return NewSentencePiece(vocab), nil
				}
			}
		}
	}
	// Otherwise translate the pre-tokenizer sequence into regex fragments
	// understood by the BPE splitter.
	var pretokenizers []string
	if tokenizer.PreTokenizer != nil {
		var pretokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					Regex string
				} `json:"pattern"`
				IndividualDigits bool `json:"individual_digits"`
			}
		}
		if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
			return nil, err
		}
		if pretokenizer.Type == "Sequence" {
			for _, pretokenizer := range pretokenizer.Pretokenizers {
				switch pretokenizer.Type {
				case "Digits":
					if pretokenizer.IndividualDigits {
						pretokenizers = append(pretokenizers, `\d`)
					} else {
						pretokenizers = append(pretokenizers, `\d+`)
					}
				case "Punctuation":
					pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
				case "Split":
					pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
				case "WhitespaceSplit":
					pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
				}
			}
		}
	}
	return NewBytePairEncoding(vocab, pretokenizers...), nil
}
// valueOrValues decodes JSON that may be either a bare value or an array of
// values, normalizing both forms to a slice.
type valueOrValues[E any] []E

// UnmarshalJSON implements [json.Unmarshaler]. An array decodes as-is; a
// single value decodes into a one-element slice.
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
	var many []E
	if json.Unmarshal(data, &many) == nil {
		*m = many
		return nil
	}
	var one E
	if err := json.Unmarshal(data, &one); err != nil {
		return err
	}
	*m = valueOrValues[E]{one}
	return nil
}

// specialTokenIDs holds BOS/EOS token IDs as found in HF config files,
// where each may be a single ID or a list of IDs.
type specialTokenIDs struct {
	BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
	EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}
// stringOrContent decodes a JSON token definition that is either a plain
// string or an object carrying the text in its "content" field.
type stringOrContent string

// UnmarshalJSON implements [json.Unmarshaler]. For object form, a missing
// or non-string "content" field yields the empty string rather than an
// error; any other JSON value is an error.
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
	var plain string
	if json.Unmarshal(data, &plain) == nil {
		*t = stringOrContent(plain)
		return nil
	}
	var obj map[string]any
	if err := json.Unmarshal(data, &obj); err != nil {
		return err
	}
	if content, ok := obj["content"].(string); ok {
		*t = stringOrContent(content)
	} else {
		*t = ""
	}
	return nil
}
// specialTokens resolves BOS/EOS token IDs and add-BOS/EOS behavior by
// consulting, in priority order: generation_config.json, config.json,
// tokenizer_config.json, then special_tokens_map.json. Earlier files win;
// later files only fill IDs still unset. values maps token text to ID and
// is used to resolve string-form tokens. Missing files are skipped; any
// other open or parse error aborts.
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
	var vocab Vocabulary
	for _, c := range []struct {
		name string
		fn   func(io.Reader) error
	}{
		{
			// Highest priority: explicit IDs from generation_config.json.
			name: "generation_config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}
				vocab.BOS = c.BOSTokenID
				vocab.EOS = c.EOSTokenID
				return nil
			},
		},
		{
			// config.json IDs are used only for IDs still unset.
			name: "config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}
				if len(vocab.BOS) == 0 {
					vocab.BOS = c.BOSTokenID
				}
				if len(vocab.EOS) == 0 {
					vocab.EOS = c.EOSTokenID
				}
				return nil
			},
		},
		{
			// tokenizer_config.json names tokens by text; resolve through
			// values. AddBOS/AddEOS are always taken from this file
			// (unconditionally overwritten). PADToken is decoded but
			// currently unused.
			name: "tokenizer_config.json",
			fn: func(r io.Reader) error {
				var c struct {
					BOSToken    stringOrContent `json:"bos_token"`
					EOSToken    stringOrContent `json:"eos_token"`
					PADToken    stringOrContent `json:"pad_token"`
					AddBOSToken bool            `json:"add_bos_token"`
					AddEOSToken bool            `json:"add_eos_token"`
				}
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}
				if len(vocab.BOS) == 0 && c.BOSToken != "" {
					if id, ok := values[string(c.BOSToken)]; ok {
						vocab.BOS = []int32{id}
					}
				}
				if len(vocab.EOS) == 0 && c.EOSToken != "" {
					if id, ok := values[string(c.EOSToken)]; ok {
						vocab.EOS = []int32{id}
					}
				}
				vocab.AddBOS = c.AddBOSToken
				vocab.AddEOS = c.AddEOSToken
				return nil
			},
		},
		{
			// Last resort: special_tokens_map.json, again by token text.
			name: "special_tokens_map.json",
			fn: func(r io.Reader) error {
				var c map[string]stringOrContent
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}
				if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
					if id, ok := values[string(bos)]; ok {
						vocab.BOS = []int32{id}
					}
				}
				if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
					if id, ok := values[string(eos)]; ok {
						vocab.EOS = []int32{id}
					}
				}
				return nil
			},
		},
	} {
		// Open each file in turn; a missing file is not an error.
		if err := func() error {
			f, err := root.Open(c.name)
			if errors.Is(err, os.ErrNotExist) {
				return nil
			} else if err != nil {
				return err
			}
			defer f.Close()
			return c.fn(f)
		}(); err != nil {
			return nil, err
		}
	}
	return &vocab, nil
}

View File

@@ -1,316 +0,0 @@
package model
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"maps"
"mime"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
// root opens the models directory as an [os.Root], creating the expected
// "manifests" and "blobs" subdirectories on first use.
func root() (*os.Root, error) {
	r, err := os.OpenRoot(envconfig.Models())
	if err != nil {
		return nil, err
	}
	for _, dir := range []string{"manifests", "blobs"} {
		_, statErr := r.Stat(dir)
		switch {
		case errors.Is(statErr, fs.ErrNotExist):
			if err := r.MkdirAll(dir, 0o750); err != nil {
				return nil, err
			}
		case statErr != nil:
			return nil, statErr
		}
	}
	return r, nil
}
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
	r, err := root()
	if err != nil {
		return nil, err
	}
	f, err := r.Open(filepath.Join("manifests", n.Filepath()))
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var m manifest
	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}
	// Index blobs by name; the config blob lives under the reserved
	// NamePrefix key.
	blobs := make(map[string]*blob, len(m.Layers)+1)
	blobs[NamePrefix] = m.Config
	for _, layer := range m.Layers {
		// Unnamed layers get a synthetic name derived from their ollama
		// media type (see NamePrefix for the scheme).
		if layer.Name == "" && layer.MediaType != "" {
			mediatype, _, err := mime.ParseMediaType(layer.MediaType)
			if err != nil {
				return nil, err
			}
			if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
				layer.Name = NamePrefix + suffix
			}
		}
		blobs[layer.Name] = layer
	}
	return &Root{
		root:  r,
		name:  n,
		blobs: blobs,
		flags: os.O_RDONLY,
	}, nil
}
// Create opens a new, writable model file. The returned [*Root] supports
// both reading and writing; blobs written through it are finalized, and the
// manifest persisted, when the file is closed.
func Create(n Name) (*Root, error) {
	base, err := root()
	if err != nil {
		return nil, err
	}
	out := &Root{
		root:  base,
		name:  n,
		blobs: map[string]*blob{},
		flags: os.O_RDWR,
	}
	return out, nil
}
// blob describes a single content-addressed file belonging to a model.
// While being written it carries a temp file and a running digest; once
// finalized, Digest/Size/MediaType are filled in for the manifest.
type blob struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Name      string `json:"name,omitempty"`
	Size      int64  `json:"size"`

	// tempfile receives the blob data until the owning Root is closed.
	tempfile *os.File

	// hash accumulates the digest of everything written so far.
	hash hash.Hash
}

// Write tees p into both the temp file and the digest accumulator.
func (b *blob) Write(p []byte) (int, error) {
	w := io.MultiWriter(b.tempfile, b.hash)
	return w.Write(p)
}

// Filepath converts the digest into an on-disk file name
// ("sha256:x" -> "sha256-x").
func (b *blob) Filepath() string {
	return strings.ReplaceAll(b.Digest, ":", "-")
}

// manifest is the JSON document stored under manifests/ describing a model:
// one config blob plus its data layers.
type manifest struct {
	SchemaVersion int     `json:"schemaVersion"`
	MediaType     string  `json:"mediaType"`
	Config        *blob   `json:"config"`
	Layers        []*blob `json:"layers"`
}
// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
	root  *os.Root         // handle on the models directory
	name  Name             // model name this Root was opened/created with
	blobs map[string]*blob // blobs indexed by name
	flags int              // os.O_RDONLY or os.O_RDWR; gates Create/Close behavior
}

// MediaTypePrefix is the media-type namespace for ollama-owned layers;
// suffixes after it mirror the reserved blob-name scheme (see NamePrefix).
const MediaTypePrefix = "application/vnd.ollama"

// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are identified by their media types:
//
// - name: NamePrefix + suffix
// - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
// - name: "./..image.model"
// - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."
// Open opens the named blob for reading. It returns [fs.ErrNotExist] when
// no blob with that name is present. The caller must close the returned
// [io.ReadCloser] when done.
func (r Root) Open(name string) (io.ReadCloser, error) {
	b, ok := r.blobs[name]
	if !ok {
		return nil, fs.ErrNotExist
	}
	f, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
	if err != nil {
		return nil, err
	}
	return f, nil
}

// ReadFile reads the entire named blob into memory.
func (r Root) ReadFile(name string) ([]byte, error) {
	rc, err := r.Open(name)
	if err != nil {
		return nil, err
	}
	defer rc.Close()
	return io.ReadAll(rc)
}
// Create creates or replaces a named blob, returning a writer for its
// contents. It returns [fs.ErrInvalid] if this Root was opened read-only.
// The writer does not need to be closed, but the Root must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
	if r.flags&os.O_RDWR == 0 {
		return nil, fs.ErrInvalid
	}
	tmp, err := os.CreateTemp(r.root.Name(), "")
	if err != nil {
		return nil, err
	}
	b := &blob{Name: name, tempfile: tmp, hash: sha256.New()}
	r.blobs[name] = b
	return b, nil
}
// Close closes the file. If the file was opened in read-write mode, it
// first finalizes every writeable blob — closing its temp file, recording
// size and sha256 digest, deriving a media type for reserved names, and
// renaming the temp file into blobs/ — and then writes the manifest under
// manifests/. The underlying directory handle is closed in all modes.
func (r *Root) Close() error {
	if r.flags&os.O_RDWR != 0 {
		for _, b := range r.blobs {
			if b.tempfile != nil {
				// Stat before Close: the size is needed for the manifest.
				fi, err := b.tempfile.Stat()
				if err != nil {
					return err
				}
				if err := b.tempfile.Close(); err != nil {
					return err
				}
				b.Size = fi.Size()
				b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
				// Reserved names are persisted as media types instead of
				// names; the bare NamePrefix is the manifest config blob.
				if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
					if b.Name == NamePrefix {
						b.MediaType = "application/vnd.docker.container.image.v1+json"
					} else {
						b.MediaType = MediaTypePrefix + suffix
					}
					b.Name = ""
				}
				// Move the finished temp file to its content-addressed home.
				// NOTE(review): early error returns in this loop leave
				// r.root open and temp files behind — worth confirming this
				// is acceptable for callers.
				rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
				if err != nil {
					return err
				}
				if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
					return err
				}
			}
		}
		// Ensure the manifest's parent directories exist, then (re)write it.
		p := filepath.Join("manifests", r.name.Filepath())
		if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
			if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
				return err
			}
		} else if err != nil {
			return err
		}
		f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
		if err != nil {
			return err
		}
		defer f.Close()
		if err := json.NewEncoder(f).Encode(manifest{
			SchemaVersion: 2,
			MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
			Config:        r.blobs[NamePrefix],
			// Layers are every blob except the config.
			Layers: func() []*blob {
				blobs := make([]*blob, 0, len(r.blobs))
				for name, b := range r.blobs {
					if name != NamePrefix {
						blobs = append(blobs, b)
					}
				}
				return blobs
			}(),
		}); err != nil {
			return err
		}
	}
	return r.root.Close()
}
// Name returns the model name this Root was opened with.
func (r Root) Name() Name {
	return r.name
}

// Names returns an iterator over the blob names in the file. Order is
// unspecified (map iteration order).
func (r Root) Names() iter.Seq[string] {
	return func(yield func(string) bool) {
		for name := range r.blobs {
			if !yield(name) {
				return
			}
		}
	}
}

// Glob returns an iterator over the blob names matching pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with
// filepath.Match, the only possible returned error is ErrBadPattern, when
// pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
	// Validate the pattern up front so iteration itself cannot fail.
	if _, err := filepath.Match(pattern, ""); err != nil {
		return nil, err
	}
	seq := func(yield func(string) bool) {
		for name := range r.blobs {
			ok, _ := filepath.Match(pattern, name)
			if ok && !yield(name) {
				return
			}
		}
	}
	return seq, nil
}

// JoinPath joins parts onto the root directory path.
func (r Root) JoinPath(parts ...string) string {
	elems := append([]string{r.root.Name()}, parts...)
	return filepath.Join(elems...)
}

// Real returns the on-disk blob file name for name, or "" if no such blob
// exists.
func (r Root) Real(name string) string {
	b, ok := r.blobs[name]
	if !ok {
		return ""
	}
	return b.Filepath()
}

View File

@@ -1,90 +0,0 @@
package model
import (
"io"
"strings"
"testing"
)
// setup populates a temporary OLLAMA_MODELS directory with the given
// models, where each model maps blob names to their contents.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
	for name, blobs := range models {
		f, err := Create(name)
		if err != nil {
			t.Fatal(err)
		}
		for blobName, content := range blobs {
			w, err := f.Create(blobName)
			if err != nil {
				t.Fatal(err)
			}
			if _, err := io.Copy(w, content); err != nil {
				t.Fatal(err)
			}
		}
		// Close finalizes the blobs and writes the manifest.
		if err := f.Close(); err != nil {
			t.Fatal(err)
		}
	}
}
// TestOpen exercises the read path: models created by setup can be
// reopened, their config blob read back and closed, unknown names fail to
// open, and read-only handles reject blob creation.
func TestOpen(t *testing.T) {
	setup(t, map[Name]map[string]io.Reader{
		ParseName("namespace/model"): {
			"./.": strings.NewReader(`{"key":"value"}`),
		},
		ParseName("namespace/model:8b"): {
			"./.": strings.NewReader(`{"foo":"bar"}`),
		},
		ParseName("another/model"): {
			"./.": strings.NewReader(`{"another":"config"}`),
		},
	})
	f, err := Open(ParseName("namespace/model"))
	if err != nil {
		t.Fatal(err)
	}
	// The config blob ("./.") must open, read fully, and close cleanly.
	for _, name := range []string{"./."} {
		r, err := f.Open(name)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.ReadAll(r); err != nil {
			t.Fatal(err)
		}
		if err := r.Close(); err != nil {
			t.Fatal(err)
		}
	}
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}
	t.Run("does not exist", func(t *testing.T) {
		if _, err := Open(ParseName("namespace/unknown")); err == nil {
			t.Error("expected error for unknown model")
		}
	})
	t.Run("write", func(t *testing.T) {
		f, err := Open(ParseName("namespace/model"))
		if err != nil {
			t.Fatal(err)
		}
		defer f.Close()
		// Open returns a read-only handle; Create must be rejected.
		if _, err := f.Create("new-blob"); err == nil {
			t.Error("expected error creating blob in read-only mode")
		}
	})
}

View File

@@ -1,33 +0,0 @@
package model
import (
"io/fs"
"iter"
"path/filepath"
)
// All returns an iterator over every model name found under the manifests
// directory. Manifest paths are four levels deep (host/namespace/model/tag)
// and are parsed back into Names.
func All() (iter.Seq[Name], error) {
	r, err := root()
	if err != nil {
		return nil, err
	}
	manifests, err := r.OpenRoot("manifests")
	if err != nil {
		return nil, err
	}
	paths, err := fs.Glob(manifests.FS(), "*/*/*/*")
	if err != nil {
		return nil, err
	}
	seq := func(yield func(Name) bool) {
		for _, p := range paths {
			if !yield(ParseNameFromFilepath(filepath.ToSlash(p))) {
				return
			}
		}
	}
	return seq, nil
}

View File

@@ -227,17 +227,6 @@ func (n Name) String() string {
return b.String()
}
// Set implements [flag.Value]: it parses s as a model name into the
// receiver. If the parsed name is not valid, ErrUnqualifiedName is
// returned; note the receiver is still overwritten in that case.
func (n *Name) Set(s string) error {
	parsed := ParseName(s)
	*n = parsed
	if parsed.IsValid() {
		return nil
	}
	return ErrUnqualifiedName
}
// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
var sb strings.Builder

View File

@@ -1,94 +0,0 @@
package mlxrunner
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
// CacheEntry is a node in the prefix trie mapping token sequences to their
// saved caches.
type CacheEntry struct {
	Caches  []cache.Cache         // caches recorded for the token path ending at this node
	Count   int                   // times this exact sequence was re-inserted after first save
	Entries map[int32]*CacheEntry // children keyed by the next token
}
// FindNearestCache walks the prefix trie of previously cached token
// sequences and returns reusable caches plus the suffix of tokens still to
// be processed. Outcomes: an exact hit (whole sequence cached, nothing
// left), a partial hit (a cached proper prefix longer than two tokens), a
// prefix hit (the walk matched tokens but no cached node on the path, so
// the shallowest cached descendant is cloned and trimmed back), or a miss
// (nil, tokens).
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
	current := &CacheEntry{Entries: s.CacheEntries}
	// Walk as deep as the trie matches; cacheIndex remembers the deepest
	// node on the path that actually holds caches.
	index, cacheIndex := 0, -1
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			break
		}
		current = current.Entries[token]
		if len(current.Caches) > 0 {
			cacheIndex = index
		}
		index += 1
	}
	if cacheIndex == len(tokens)-1 {
		// Exact hit: every token is covered by the cached node.
		slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
		return current.Caches, []int32{}
	} else if cacheIndex > 1 {
		// Partial hit: reuse the caches covering tokens[:cacheIndex+1].
		slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
		return current.Caches, tokens[cacheIndex+1:]
	} else if index > 0 && cacheIndex < 0 {
		// The walk matched a prefix but no node along it has caches: DFS
		// below the deepest matched node for the cached descendant
		// reachable through the fewest extra tokens.
		type stackItem struct {
			entry  *CacheEntry
			tokens []int32
		}
		var best, item stackItem
		stack := []stackItem{{entry: current, tokens: []int32{}}}
		for len(stack) > 0 {
			item, stack = stack[len(stack)-1], stack[:len(stack)-1]
			if len(item.entry.Caches) > 0 {
				// len(best.tokens) == 0 doubles as the "unset" sentinel.
				if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
					best = item
				}
			} else {
				for token, entry := range item.entry.Entries {
					stack = append(stack, stackItem{
						entry:  entry,
						tokens: append(item.tokens, token),
					})
				}
			}
		}
		// NOTE(review): if no descendant holds caches, best.entry stays nil
		// and the dereference below panics — confirm InsertCache can never
		// leave a cacheless subtree (e.g. via an empty caches slice).
		prefix := min(len(tokens)-1, index)
		caches := make([]cache.Cache, len(best.entry.Caches))
		// Clone before trimming so the stored caches are not mutated; trim
		// off the extra tokens that led down to the descendant.
		trim := len(best.tokens) + 1
		for i := range caches {
			caches[i] = best.entry.Caches[i].Clone()
			caches[i].Trim(trim)
		}
		slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
		return caches, tokens[prefix:]
	}
	slog.Info("Cache miss", "left", len(tokens))
	return nil, tokens
}
// InsertCache stores caches in the token trie under the path spelled by
// tokens, creating intermediate nodes as needed. If the terminal node
// already holds caches they are kept and only its hit count is bumped.
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
	node := &CacheEntry{Entries: s.CacheEntries}
	for _, tok := range tokens {
		child, ok := node.Entries[tok]
		if !ok {
			child = &CacheEntry{Entries: make(map[int32]*CacheEntry)}
			node.Entries[tok] = child
		}
		node = child
	}
	if len(node.Caches) == 0 {
		node.Caches = caches
		return
	}
	node.Count += 1
}

View File

@@ -1,196 +0,0 @@
package cache
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// Cache is the per-layer key/value cache contract: Update appends new
// keys/values and returns the full views to attend over; State exposes
// current contents; Trim drops up to n trailing positions; Clone deep
// copies; Offset reports positions seen; Len reports entries held.
type Cache interface {
	Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
	State() (keys, values *mlx.Array)
	Trim(int) int
	Clone() Cache
	Offset() int
	Len() int
}
// KVCache is an unbounded cache whose backing buffers grow in fixed-size
// steps along the sequence axis.
type KVCache struct {
	keys, values *mlx.Array // backing buffers; axis 2 is sequence length
	offset       int        // number of valid positions written so far
	step         int        // growth increment, in positions
}
// NewKVCache returns an empty cache that grows 256 positions at a time.
func NewKVCache() *KVCache {
	return &KVCache{step: 256}
}
// Update appends keys/values (B, H, L, D) at the current offset, growing
// the buffers in multiples of c.step when capacity runs out, and returns
// views over all cached positions so far.
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
	prev := c.offset
	// Grow buffer if needed
	if c.keys == nil || (prev+L) > c.keys.Dim(2) {
		// steps == ceil(L/step): enough whole growth steps for L new tokens.
		steps := (c.step + L - 1) / c.step
		newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
		if c.keys != nil {
			// Drop any partially used tail from the previous growth before
			// concatenating fresh zeroed space.
			if prev%c.step != 0 {
				c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
				c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
			}
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
	}
	c.offset += L
	// Write the new positions into [prev, offset).
	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
// State returns views over the written portion of the cache, avoiding a
// slice when the buffers are exactly full.
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset == c.keys.Dim(2) {
		return c.keys, c.values
	}
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
// Trim discards up to n of the most recent positions (by rewinding the
// offset; the data stays in the buffer) and reports how many were cut.
func (c *KVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	return n
}
// Clone returns an independent copy of the cache, cloning the backing
// arrays so later updates do not alias.
func (c *KVCache) Clone() Cache {
	return &KVCache{
		keys:   c.keys.Clone(),
		values: c.values.Clone(),
		offset: c.offset,
		step:   c.step,
	}
}
// Offset reports the number of positions written (and kept) so far.
func (c *KVCache) Offset() int { return c.offset }
// Len equals Offset for the unbounded cache.
func (c *KVCache) Len() int    { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
	maxSize int // window size: at most this many positions are retained
	idx     int // ring-buffer write cursor into the backing arrays
	*KVCache
}
// NewRotatingKVCache returns a sliding-window cache retaining at most
// maxSize positions.
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
	return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
// Update routes multi-token (prefill) batches through concat and
// single-token (decode) steps through the ring-buffer update path.
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	if keys.Dim(2) > 1 {
		return c.concat(keys, values)
	}
	return c.update(keys, values)
}
// concat handles multi-token (prefill) updates by concatenating onto the
// cached keys/values, first trimming the front so at most maxSize-1 prior
// positions remain ahead of the incoming tokens.
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
	slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	if c.keys == nil {
		// First update: adopt the incoming arrays directly.
		// NOTE(review): this aliases the caller's arrays — confirm callers
		// never mutate keys/values after handing them in.
		c.keys, c.values = keys, values
	} else {
		// Drop unwritten tail capacity beyond the write cursor.
		if c.idx < c.keys.Dim(2) {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
		}
		// Trim to max_size to maintain sliding window
		if trim := c.idx - c.maxSize + 1; trim > 0 {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		}
		c.keys.Set(c.keys.Concatenate(2, keys))
		c.values.Set(c.values.Concatenate(2, values))
	}
	c.offset += keys.Dim(2)
	// After a concat the write cursor always sits past the stored window.
	// (A duplicate assignment that previously ran on the non-nil branch was
	// removed: c.keys does not change between the two, so one assignment
	// here is equivalent.)
	c.idx = c.keys.Dim(2)
	return c.keys, c.values
}
// update handles single-token (decode) steps, writing into a ring buffer
// of at most maxSize positions; idx is the ring write cursor.
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
	prev := c.offset
	// Grow buffer if not yet at max
	if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
		// Grow by one step, but never beyond the window bound.
		newSize := min(c.step, c.maxSize-prev)
		newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
		if c.keys != nil {
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
		c.idx = prev
	}
	// Trim to max_size to maintain sliding window
	if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
		c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
		c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		c.idx = c.maxSize
	}
	// Rotate when hitting max
	if c.idx >= c.maxSize {
		c.idx = 0
	}
	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
	c.offset += L
	c.idx += L
	// Only min(offset, maxSize) positions contain real data.
	validLen := min(c.offset, c.maxSize)
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
// State returns views over the written portion of the window (the whole
// buffer once it has filled).
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset < c.keys.Dim(2) {
		return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
			c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
	}
	return c.keys, c.values
}
// Trim rewinds up to n positions, moving the write cursor back with the
// offset. NOTE(review): idx is not clamped here and could go negative if
// n exceeds the cursor — confirm callers only trim recently added tokens.
func (c *RotatingKVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	c.idx -= n
	return n
}
// Clone returns an independent copy, deep-copying the embedded KVCache.
func (c *RotatingKVCache) Clone() Cache {
	return &RotatingKVCache{
		maxSize: c.maxSize,
		idx:     c.idx,
		KVCache: c.KVCache.Clone().(*KVCache),
	}
}
// Len reports how many positions are live, capped at the window size.
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

View File

@@ -1,174 +0,0 @@
package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"math"
"net"
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// Client drives an out-of-process MLX runner over HTTP on localhost,
// embedding the exec.Cmd that owns the runner process.
type Client struct {
	Port int
	*exec.Cmd
}

// JoinPath returns the absolute URL for path on the local runner.
func (c *Client) JoinPath(path string) string {
	base := url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
	}
	return base.JoinPath(path).String()
}
// CheckError converts an HTTP error status (>= 400) into a Go error built
// from the response status line; success statuses yield nil.
func (c *Client) CheckError(w *http.Response) error {
	if w.StatusCode >= 400 {
		return errors.New(w.Status)
	}
	return nil
}
// Close implements llm.LlamaServer.
// It force-kills the runner process and returns any error from Kill.
func (c *Client) Close() error {
	return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer. It posts req to the runner's
// /v1/completions endpoint and streams line-delimited JSON responses to
// fn until the stream ends or ctx is cancelled.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(req); err != nil {
		return err
	}
	// Previously http.Post was used, which ignored ctx entirely; build the
	// request with the context so cancellation aborts the stream.
	hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.JoinPath("/v1/completions"), &b)
	if err != nil {
		return err
	}
	hreq.Header.Set("Content-Type", "application/json")
	w, err := http.DefaultClient.Do(hreq)
	if err != nil {
		return err
	}
	defer w.Body.Close()
	if err := c.CheckError(w); err != nil {
		return err
	}
	scanner := bufio.NewScanner(w.Body)
	// A single response line can exceed bufio.Scanner's 64K default; allow
	// lines up to 16 MiB before failing with ErrTooLong.
	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	for scanner.Scan() {
		var resp llm.CompletionResponse
		if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
			return err
		}
		fn(resp)
	}
	// Previously a read error (connection reset, oversized line) was
	// silently dropped, making a truncated stream look like a clean finish.
	return scanner.Err()
}
// ContextLength reports an effectively unlimited context window; the MLX
// runner enforces no limit on this side.
func (c *Client) ContextLength() int {
	return math.MaxInt
}
// Detokenize implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
	panic("unimplemented")
}
// Embedding implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	panic("unimplemented")
}
// GetDeviceInfos implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	panic("unimplemented")
}
// GetPort implements llm.LlamaServer.
// It returns the localhost port the runner listens on.
func (c *Client) GetPort() int {
	return c.Port
}
// HasExited implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) HasExited() bool {
	panic("unimplemented")
}
// Load implements llm.LlamaServer.
// It asks the runner to load its model via POST /v1/models and returns an
// empty device list. NOTE(review): ctx is not wired into the request and
// the response status is never checked — confirm that is acceptable here.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
	w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()
	return []ml.DeviceID{}, nil
}
// ModelPath implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) ModelPath() string {
	panic("unimplemented")
}
// Pid implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) Pid() int {
	panic("unimplemented")
}
// Ping implements llm.LlamaServer. It checks that the runner responds on
// /v1/status, honoring ctx for cancellation and timeouts (the previous
// http.Get ignored ctx entirely).
func (c *Client) Ping(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.JoinPath("/v1/status"), nil)
	if err != nil {
		return err
	}
	w, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer w.Body.Close()
	return nil
}
// Tokenize implements llm.LlamaServer. It posts content to the runner's
// /v1/tokenize endpoint and decodes the resulting token IDs, honoring ctx
// for cancellation (the previous http.Post ignored ctx entirely).
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.JoinPath("/v1/tokenize"), strings.NewReader(content))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "text/plain")
	w, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()
	var tokens []int
	if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
		return nil, err
	}
	return tokens, nil
}
// TotalSize implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) TotalSize() uint64 {
	panic("unimplemented")
}
// VRAMByGPU implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
	panic("unimplemented")
}
// VRAMSize implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) VRAMSize() uint64 {
	panic("unimplemented")
}
// WaitUntilRunning implements llm.LlamaServer.
// Not supported by the MLX client; callers must not use it.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
	panic("unimplemented")
}
// Compile-time check that Client satisfies the llm.LlamaServer interface.
var _ llm.LlamaServer = (*Client)(nil)

View File

@@ -1,3 +0,0 @@
_deps
build
dist

View File

@@ -1,26 +0,0 @@
# Build configuration for the bundled MLX / mlx-c native library.
cmake_minimum_required(VERSION 3.5)
project(mlx)
# Default the install prefix to ./dist unless the caller overrode it.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
# Feature toggles: safetensors on, gguf and mlx-c examples off; build
# shared libraries so the Go side can dlopen them.
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
# Resolve shared-library dependencies relative to the loading binary.
set(CMAKE_INSTALL_RPATH "@loader_path")
# Fetch mlx-c at the pinned tag; it brings in mlx itself.
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

View File

@@ -1,45 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"math"
"sync"
)
// geluApprox lazily compiles (once) the tanh approximation of GELU:
// 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), cast to the input dtype.
var geluApprox = sync.OnceValue(func() *Closure {
	return Compile(func(inputs []*Array) []*Array {
		input := inputs[0]
		return []*Array{
			input.Multiply(
				FromValue[float32](0.5),
			).Multiply(
				input.Add(
					input.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
				).Multiply(
					FromValue(float32(math.Sqrt(2 / math.Pi))),
				).Tanh().Add(FromValue[float32](1.0)),
			).AsType(input.DType()),
		}
	}, true)
})
// silu lazily compiles SiLU (x * sigmoid(x)), cast to the input dtype.
var silu = sync.OnceValue(func() *Closure {
	return Compile(func(inputs []*Array) []*Array {
		input := inputs[0]
		return []*Array{
			input.Multiply(
				input.Sigmoid(),
			).AsType(input.DType()),
		}
	}, true)
})
// GELUApprox applies the compiled approximate-GELU closure to t.
func GELUApprox(t *Array) *Array {
	return geluApprox().Call([]*Array{t})[0]
}
// SILU applies the compiled SiLU closure to t.
func SILU(t *Array) *Array {
	return silu().Call([]*Array{t})[0]
}

View File

@@ -1,271 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"log/slog"
"reflect"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
// tensorDesc carries bookkeeping for an Array: a debug name, the arrays
// it was computed from, and the reference count consumed by Free.
type tensorDesc struct {
	name    string
	inputs  []*Array
	numRefs int
}
// LogValue implements slog.LogValuer for compact trace logging.
func (d tensorDesc) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", d.name),
		slog.Int("inputs", len(d.inputs)),
		slog.Int("num_refs", d.numRefs),
	)
}
// Array wraps an mlx_array handle together with its lineage metadata.
type Array struct {
	ctx  C.mlx_array
	desc tensorDesc
}
// constructor utilities
// New allocates an Array shell (no underlying mlx_array yet) named name,
// recording inputs as its producers and bumping each input's reference
// count by one per consumer.
func New(name string, inputs ...*Array) *Array {
	t := &Array{
		desc: tensorDesc{
			name:   name,
			inputs: inputs,
		},
	}
	for _, input := range inputs {
		input.desc.numRefs++
	}
	logutil.Trace("New", "t", t)
	return t
}
// scalarTypes enumerates the Go scalar kinds FromValue accepts.
type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}
// FromValue wraps a single Go scalar as a 0-dimensional mlx array, using
// the corresponding mlx scalar constructor (int maps to mlx's int).
func FromValue[T scalarTypes](t T) *Array {
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		// Unreachable for named types matching the constraint via ~; kept
		// as a guard for derived types not handled by the switch above.
		panic("unsupported type")
	}
	return tt
}
// arrayTypes enumerates the element kinds FromValues accepts.
type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}
// FromValues builds an mlx array from a Go slice and an explicit shape,
// serializing the data little-endian and mapping the element kind to the
// matching mlx dtype. Panics if no shape is given or the kind is unknown.
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	if len(shape) == 0 {
		panic("shape must be provided for non-scalar tensors")
	}
	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}
	// Map the Go element kind onto the corresponding mlx dtype.
	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("unsupported type")
	}
	// Serialize to little-endian bytes; mlx copies out of this buffer.
	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}
	tt := New("")
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}
// Set rebinds t to share other's underlying mlx array, recording other as
// t's sole input and bumping its reference count.
func (t *Array) Set(other *Array) {
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}
// Clone returns a new Array sharing t's mlx handle, name, and inputs.
func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}
// misc. utilities
// Valid reports whether t currently wraps a live mlx array handle.
func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}
// String renders the array contents via mlx's own formatter.
func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
// LogValue implements slog.LogValuer, adding dtype/shape/byte size when
// the array is live.
func (t *Array) LogValue() slog.Value {
	attrs := []slog.Attr{slog.Any("", t.desc)}
	if t.Valid() {
		attrs = append(attrs,
			slog.Any("dtype", t.DType()),
			slog.Any("shape", t.Dims()),
			slog.Int("num_bytes", t.NumBytes()),
		)
	}
	return slog.GroupValue(attrs...)
}
// shape utilities
// Size returns the total number of elements.
func (t Array) Size() int {
	return int(C.mlx_array_size(t.ctx))
}
// NumBytes returns the total storage size in bytes.
func (t Array) NumBytes() int {
	return int(C.mlx_array_nbytes(t.ctx))
}
// NumDims returns the number of dimensions (rank).
func (t Array) NumDims() int {
	return int(C.mlx_array_ndim(t.ctx))
}
// Dims returns the full shape as a slice, one entry per dimension.
func (t Array) Dims() []int {
	dims := make([]int, t.NumDims())
	for i := range dims {
		dims[i] = t.Dim(i)
	}
	return dims
}
// Dim returns the extent of a single dimension.
func (t Array) Dim(dim int) int {
	return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
// DType returns the element type of the array.
func (t Array) DType() DType {
	return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
// Int returns the scalar value as an int (via mlx's int64 item accessor).
func (t Array) Int() int {
	var item C.int64_t
	C.mlx_array_item_int64(&item, t.ctx)
	return int(item)
}
// Float returns the scalar value as a float64.
func (t Array) Float() float64 {
	var item C.double
	C.mlx_array_item_float64(&item, t.ctx)
	return float64(item)
}
// Ints copies the array out as a []int.
// NOTE(review): reads via the int32 data pointer — assumes the array's
// dtype is int32; confirm callers guarantee that.
func (t Array) Ints() []int {
	ints := make([]int, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
		ints[i] = int(f)
	}
	return ints
}
// Floats copies the array out as a []float32.
// NOTE(review): reads via the float32 data pointer — assumes float32 dtype.
func (t Array) Floats() []float32 {
	floats := make([]float32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
		floats[i] = float32(f)
	}
	return floats
}
// Save writes the array to disk via mlx_save. The C call reports no
// status, so this always returns nil.
func (t Array) Save(name string) error {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))
	C.mlx_save(cName, t.ctx)
	return nil
}
// Free releases the given arrays and, transitively, their producer arrays
// whose reference counts drop to zero, returning the number of bytes
// freed. Traversal uses an explicit stack rather than recursion so long
// producer chains cannot overflow the goroutine stack.
func Free(s ...*Array) (n int) {
	now := time.Now()
	defer func() {
		if n > 0 {
			logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
		}
	}()
	free := make([]*Array, 0, 8192)
	// fn decrements one array's refcount, freeing its mlx handle at zero.
	// Its inputs are queued on every visit, mirroring the per-consumer
	// increment that New applied when the array was created.
	fn := func(t *Array) {
		if t.Valid() {
			free = append(free, t.desc.inputs...)
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				logutil.Trace("Free", "t", t)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				t.ctx.ctx = nil
			}
		}
	}
	for _, t := range s {
		fn(t)
	}
	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}
	return n
}

View File

@@ -1,43 +0,0 @@
package mlx
import "testing"
func TestFromValue(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}
func TestFromValues(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}

View File

@@ -1,76 +0,0 @@
package mlx
// #include "generated.h"
// int goClosureFunc(mlx_vector_array*, mlx_vector_array, void*);
// void goClosureDestructor(void*);
import "C"
import (
"runtime/cgo"
"unsafe"
)
// Closure wraps a compiled mlx closure handle.
type Closure struct {
	ctx C.mlx_closure
}
// Call applies the closure to inputs and wraps the resulting mlx arrays;
// every output records all of the call's inputs as its producers so that
// Free can account for them.
func (c Closure) Call(inputs []*Array) []*Array {
	inputsVector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(inputsVector)
	for _, input := range inputs {
		C.mlx_vector_array_append_value(inputsVector, input.ctx)
	}
	outputsVector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(outputsVector)
	C.mlx_closure_apply(&outputsVector, c.ctx, inputsVector)
	outputs := make([]*Array, int(C.mlx_vector_array_size(outputsVector)))
	for i := range outputs {
		t := New("", inputs...)
		C.mlx_vector_array_get(&t.ctx, outputsVector, C.size_t(i))
		outputs[i] = t
	}
	return outputs
}
// Compile turns a Go function over arrays into a compiled mlx closure.
// The Go function is passed through a cgo.Handle payload; mlx releases it
// via goClosureDestructor. shapeless compiles without shape
// specialization. NOTE(review): the intermediate closure handle is never
// freed here — confirm mlx_compile takes ownership.
func Compile(fn func([]*Array) []*Array, shapeless bool) *Closure {
	closure := C.mlx_closure_new_func_payload(
		(*[0]byte)(C.goClosureFunc),
		unsafe.Pointer(cgo.NewHandle(fn)),
		(*[0]byte)(C.goClosureDestructor),
	)
	compiled := C.mlx_closure_new()
	C.mlx_compile(&compiled, closure, C.bool(shapeless))
	return &Closure{ctx: compiled}
}
// goClosureFunc is the cgo trampoline invoked by mlx when a compiled
// closure runs: it recovers the Go function from the handle payload,
// wraps the input vector, runs the function, and copies the outputs back
// into the mlx output vector. Always returns 0 (success).
//
//export goClosureFunc
func goClosureFunc(outputsVector *C.mlx_vector_array, inputsVector C.mlx_vector_array, payload unsafe.Pointer) C.int {
	handle := cgo.Handle(payload)
	fn := handle.Value().(func([]*Array) []*Array)
	inputs := make([]*Array, int(C.mlx_vector_array_size(inputsVector)))
	for i := range inputs {
		t := New("")
		C.mlx_vector_array_get(&t.ctx, inputsVector, C.size_t(i))
		inputs[i] = t
	}
	var outputs []C.mlx_array
	for _, output := range fn(inputs) {
		outputs = append(outputs, output.ctx)
	}
	C.mlx_vector_array_set_data(outputsVector, unsafe.SliceData(outputs), C.size_t(len(outputs)))
	return 0
}
// goClosureDestructor releases the cgo.Handle holding the Go function
// when mlx destroys the closure.
//
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
	cgo.Handle(payload).Delete()
}

View File

@@ -1,94 +0,0 @@
package mlx
// #include "generated.h"
import "C"
// DType identifies an mlx element type; values mirror the MLX_* C enum.
type DType int
// String returns the short safetensors-style name for the dtype, or
// "Unknown" for unrecognized values.
func (t DType) String() string {
	switch t {
	case DTypeBool:
		return "BOOL"
	case DTypeUint8:
		return "U8"
	case DTypeUint16:
		return "U16"
	case DTypeUint32:
		return "U32"
	case DTypeUint64:
		return "U64"
	case DTypeInt8:
		return "I8"
	case DTypeInt16:
		return "I16"
	case DTypeInt32:
		return "I32"
	case DTypeInt64:
		return "I64"
	case DTypeFloat16:
		return "F16"
	case DTypeFloat32:
		return "F32"
	case DTypeFloat64:
		return "F64"
	case DTypeBFloat16:
		return "BF16"
	case DTypeComplex64:
		return "C64"
	default:
		return "Unknown"
	}
}
// UnmarshalJSON parses the short dtype names produced by String.
// NOTE(review): unrecognized strings leave *t unchanged and return nil —
// confirm that silent acceptance is intended rather than an error.
func (t *DType) UnmarshalJSON(b []byte) error {
	switch string(b) {
	case `"BOOL"`:
		*t = DTypeBool
	case `"U8"`:
		*t = DTypeUint8
	case `"U16"`:
		*t = DTypeUint16
	case `"U32"`:
		*t = DTypeUint32
	case `"U64"`:
		*t = DTypeUint64
	case `"I8"`:
		*t = DTypeInt8
	case `"I16"`:
		*t = DTypeInt16
	case `"I32"`:
		*t = DTypeInt32
	case `"I64"`:
		*t = DTypeInt64
	case `"F16"`:
		*t = DTypeFloat16
	case `"F64"`:
		*t = DTypeFloat64
	case `"F32"`:
		*t = DTypeFloat32
	case `"BF16"`:
		*t = DTypeBFloat16
	case `"C64"`:
		*t = DTypeComplex64
	default:
		return nil
	}
	return nil
}
// DType values, aliased one-to-one from the mlx-c MLX_* dtype enum.
const (
	DTypeBool      DType = C.MLX_BOOL
	DTypeUint8     DType = C.MLX_UINT8
	DTypeUint16    DType = C.MLX_UINT16
	DTypeUint32    DType = C.MLX_UINT32
	DTypeUint64    DType = C.MLX_UINT64
	DTypeInt8      DType = C.MLX_INT8
	DTypeInt16     DType = C.MLX_INT16
	DTypeInt32     DType = C.MLX_INT32
	DTypeInt64     DType = C.MLX_INT64
	DTypeFloat16   DType = C.MLX_FLOAT16
	DTypeFloat32   DType = C.MLX_FLOAT32
	DTypeFloat64   DType = C.MLX_FLOAT64
	DTypeBFloat16  DType = C.MLX_BFLOAT16
	DTypeComplex64 DType = C.MLX_COMPLEX64
)

View File

@@ -1,34 +0,0 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
/* mlx_dynamic_open loads the shared library at path into the process and
 * stores the platform handle in handle->ctx. Returns 0 on success or 1
 * (via CHECK) if the library could not be opened. */
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
    handle->ctx = (void*) DLOPEN(path);
    CHECK(handle->ctx != NULL);
    return 0;
}
/* mlx_dynamic_load is the public wrapper around mlx_dynamic_open. */
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
    return mlx_dynamic_open(handle, path);
}
/* mlx_dynamic_unload closes the library if one is open and clears the
 * handle so a second unload is a harmless no-op. */
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
    if (handle->ctx) {
        DLCLOSE(handle->ctx);
        handle->ctx = NULL;
    }
}

View File

@@ -1,63 +0,0 @@
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"unsafe"
)
// init loads the MLX C shared library at startup on the platforms that
// ship one (darwin, windows). Each directory listed in
// OLLAMA_LIBRARY_PATH is searched for libmlxc.*; the first candidate
// whose symbols all resolve wins. Panics if the variable is set but no
// candidate loads.
func init() {
	switch runtime.GOOS {
	case "darwin", "windows":
	default:
		return
	}
	paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
	if !ok {
		slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
		return
	}
	for _, dir := range filepath.SplitList(paths) {
		matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
		if err != nil {
			panic(err)
		}
		for _, match := range matches {
			// BUG FIX: the match was previously joined with the unsplit env
			// value (paths), yielding a bogus path whenever the variable held
			// more than one directory; join with this iteration's directory.
			libPath := filepath.Join(dir, match)
			slog.Info("Loading MLX dynamic library", "path", libPath)
			cPath := C.CString(libPath)
			var handle C.mlx_dynamic_handle
			loaded := C.mlx_dynamic_load(&handle, cPath) == 0
			// Free eagerly rather than deferring inside the loop so failed
			// candidates do not accumulate allocations until init returns.
			C.free(unsafe.Pointer(cPath))
			if !loaded {
				slog.Error("Failed to load MLX dynamic library", "path", libPath)
				continue
			}
			if C.mlx_dynamic_load_symbols(handle) != 0 {
				slog.Error("Failed to load MLX dynamic library symbols", "path", libPath)
				C.mlx_dynamic_unload(&handle)
				continue
			}
			slog.Info("Loaded MLX dynamic library", "path", libPath)
			return
		}
	}
	panic("Failed to load any MLX dynamic library")
}

View File

@@ -1,27 +0,0 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
/* Platform shims for resolving symbols from a dynamically loaded MLX C
 * library. */
#ifdef _WIN32
#include <windows.h>
/* NOTE(review): the Windows variant casts the handle struct itself while
 * the POSIX variant uses handle.ctx — confirm the Windows path compiles
 * and receives the raw module handle it expects. */
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
/* ERROR logs to stderr with build/file/line context, then returns 1 from
 * the enclosing function; CHECK applies it when a condition fails. These
 * expand to multiple statements, so they are only safe inside braces, as
 * used below. */
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
/* CHECK_LOAD resolves symbol x into the generated x##_ function pointer
 * and fails the enclosing function if it is missing. */
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
/* mlx_dynamic_handle wraps the OS shared-library handle. */
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

View File

@@ -1,72 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
// ScaledDotProductAttention runs mlx's fused attention kernel over
// query/key/value with the given scale. A nil mask is replaced by an
// empty array; attention sinks are always empty.
// NOTE(review): the mode string is fixed to "causal" even when an
// explicit mask is supplied — confirm that interaction is intended.
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
	if mask == nil {
		mask = New("")
	}
	sinks := New("")
	mode := "causal"
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	out := New("FAST_SDPA", query, key, value, mask, sinks)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}
// LayerNorm holds the affine weight and bias for mlx's fused layer norm.
type LayerNorm struct {
	Weight Array `weight:"weight"`
	Bias   Array `weight:"bias"`
}
// Forward normalizes x with the stored weight/bias and epsilon eps via
// mlx_fast_layer_norm.
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}
// RMSNorm holds the scale weight for mlx's fused RMS norm.
type RMSNorm struct {
	Weight Array `weight:"weight"`
}
// Forward applies RMS normalization to x with epsilon eps via
// mlx_fast_rms_norm.
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}
// RoPE applies rotary position embeddings via mlx_fast_rope.
type RoPE struct {
	Dims        int     // number of leading feature dims to rotate
	Traditional bool    // use the traditional (interleaved) rotation
	Base        float32 `json:"rope_theta"` // frequency base; 0 means "use mlx default"
	Scale       float32 // position scaling factor
}
// Forward rotates t starting at sequence position offset. Custom
// frequency arrays are not used (freqs is passed empty).
func (r RoPE) Forward(t *Array, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", t, freqs)
	C.mlx_fast_rope(
		&out.ctx,
		t.ctx,
		C.int(r.Dims),
		C._Bool(r.Traditional),
		// has_value is false when Base == 0 so mlx falls back to its
		// default theta.
		C.mlx_optional_float{
			value:     C.float(r.Base),
			has_value: C._Bool(func() bool { return r.Base != 0 }()),
		},
		C.float(r.Scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +0,0 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
CHECK_LOAD(handle, {{ .Name }});
{{- end }}
return 0;
}
{{- range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}

View File

@@ -1,20 +0,0 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }};
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -1,135 +0,0 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
// Function describes one C function signature extracted from a header:
// return type, name, the raw parameter-list text, and the comma-joined
// argument names used to forward calls in generated wrappers.
type Function struct {
	Type,
	Name,
	Parameters,
	Args string
}
// ParseFunction builds a Function from a function_declarator node,
// reconstructing the return type by collecting "*" from enclosing
// pointer declarators and the declaration's leading specifier siblings.
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
	var fn Function
	fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
	if params := node.ChildByFieldName("parameters"); params != nil {
		fn.Parameters = params.Utf8Text(source)
		fn.Args = ParseParameters(params, tc, source)
	}
	// Each pointer_declarator wrapping the name adds one "*" to the type.
	var types []string
	for node.Parent() != nil && node.Parent().Kind() != "declaration" {
		if node.Parent().Kind() == "pointer_declarator" {
			types = append(types, "*")
		}
		node = node.Parent()
	}
	// Type specifiers precede the declarator; walk them right-to-left,
	// then restore source order before joining.
	for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
		types = append(types, sibling.Utf8Text(source))
	}
	slices.Reverse(types)
	fn.Type = strings.Join(types, " ")
	return fn
}
// ParseParameters extracts the bare argument identifiers from a
// parameter-list node, drilling through pointer/parenthesized declarators
// (e.g. function-pointer parameters) to reach each identifier, and joins
// them with ", " for use as a forwarding call's argument list.
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
	var s []string
	for _, child := range node.Children(tc) {
		if child.IsNamed() {
			child := child.ChildByFieldName("declarator")
			for child != nil && child.Kind() != "identifier" {
				if child.Kind() == "parenthesized_declarator" {
					// Function-pointer style: step inside the parentheses.
					child = child.Child(1)
				} else {
					child = child.ChildByFieldName("declarator")
				}
			}
			if child != nil {
				s = append(s, child.Utf8Text(source))
			}
		}
	}
	return strings.Join(s, ", ")
}
// main parses the C headers named on the command line with tree-sitter,
// extracts every declared function, and renders the embedded .gotmpl
// templates into the output directory. Fixes over the previous version:
// the NewQuery error is no longer discarded, and per-file/per-template
// resources (parse trees, cursors, output files) are now closed promptly
// by helper functions instead of accumulating defers inside loops.
func main() {
	var output string
	flag.StringVar(&output, "output", ".", "Output directory for generated files")
	flag.Parse()
	parser := tree_sitter.NewParser()
	defer parser.Close()
	language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
	parser.SetLanguage(language)
	query, err := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error compiling query: %v\n", err)
		return
	}
	defer query.Close()
	qc := tree_sitter.NewQueryCursor()
	defer qc.Close()
	var funs []Function
	for _, arg := range flag.Args() {
		funs = append(funs, parseFile(parser, query, qc, arg)...)
	}
	tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
		return
	}
	for _, tmpl := range tmpl.Templates() {
		writeTemplate(output, tmpl, funs)
	}
}

// parseFile parses one source file and returns the functions it declares;
// its parse tree and cursor are released before the next file is read.
func parseFile(parser *tree_sitter.Parser, query *tree_sitter.Query, qc *tree_sitter.QueryCursor, arg string) []Function {
	bts, err := os.ReadFile(arg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
		return nil
	}
	tree := parser.Parse(bts, nil)
	defer tree.Close()
	tc := tree.Walk()
	defer tc.Close()
	var funs []Function
	matches := qc.Matches(query, tree.RootNode(), bts)
	for match := matches.Next(); match != nil; match = matches.Next() {
		for _, capture := range match.Captures {
			funs = append(funs, ParseFunction(&capture.Node, tc, bts))
		}
	}
	return funs
}

// writeTemplate renders one template into the output directory, closing
// the file as soon as rendering finishes.
func writeTemplate(output string, tmpl *template.Template, funs []Function) {
	name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
	fmt.Println("Generating", name)
	f, err := os.Create(name)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
		return
	}
	defer f.Close()
	if err := tmpl.Execute(f, map[string]any{
		"Functions": funs,
	}); err != nil {
		fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
	}
}

View File

@@ -1,166 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"encoding/json"
"errors"
"io"
"iter"
"log/slog"
"maps"
"slices"
"unsafe"
"github.com/ollama/ollama/types/model"
)
// Load opens the safetensors file at path on the CPU stream and yields
// each tensor by name. Yielded arrays are created with a large refcount
// (1000) so the reference-counted Free will not reclaim model weights
// incidentally.
func Load(path string) iter.Seq2[string, *Array] {
	return func(yield func(string, *Array) bool) {
		string2array := C.mlx_map_string_to_array_new()
		defer C.mlx_map_string_to_array_free(string2array)
		string2string := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_string_free(string2string)
		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath))
		cpu := C.mlx_default_cpu_stream_new()
		defer C.mlx_stream_free(cpu)
		C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
		it := C.mlx_map_string_to_array_iterator_new(string2array)
		defer C.mlx_map_string_to_array_iterator_free(it)
		for {
			var key *C.char
			value := C.mlx_array_new()
			if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
				break
			}
			name := C.GoString(key)
			if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
				break
			}
		}
	}
}
// Parse reads the safetensors JSON header of the blob at path under root
// and returns the per-tensor quantization settings recorded in its
// __metadata__ section.
func Parse(root *model.Root, path string) (map[string]Quantization, error) {
	f, err := root.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	// safetensors layout: an 8-byte little-endian header length followed
	// by a JSON header of exactly that many bytes.
	var n uint64
	if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
		return nil, err
	}
	bts := make([]byte, n)
	if _, err := io.ReadFull(f, bts); err != nil {
		return nil, err
	}
	var m struct {
		Metadata struct {
			Quantization map[string]Quantization `json:"quantization"`
		} `json:"__metadata__"`
	}
	if err := json.Unmarshal(bts, &m); err != nil {
		return nil, err
	}
	return m.Metadata.Quantization, nil
}
// LoadWeights copies tensors from the safetensors blob backing match into
// the matching entries of states; tensors whose names are not in states are
// skipped. The error result is currently always nil; it is kept so the
// signature matches the other Load* helpers.
func LoadWeights(root *model.Root, match string, states map[string]*Array) error {
	slog.Debug("Loading weights from", "file", match)
	blob := root.JoinPath("blobs", root.Real(match))
	for name, arr := range Load(blob) {
		dst, wanted := states[name]
		if !wanted {
			continue
		}
		*dst = *arr
	}
	return nil
}
// LoadQuantizations reads the safetensors metadata of match and fills in the
// quantization parameters (group size, bits, mode) for every tensor that has
// a corresponding "<name>.weight" entry in quantizations.
func LoadQuantizations(root *model.Root, match string, quantizations map[string]*Quantization) error {
	slog.Debug("Loading quantizations from", "file", match)
	metadata, err := Parse(root, match)
	if err != nil {
		return err
	}
	// Range over key and value together: the original looked metadata[name]
	// up three more times per key.
	for name, meta := range metadata {
		// Metadata keys name the tensor without the ".weight" suffix used by
		// the quantizations map, so translate before looking up.
		if q, ok := quantizations[name+".weight"]; ok {
			q.GroupSize = meta.GroupSize
			q.Bits = meta.Bits
			q.Mode = meta.Mode
		}
	}
	return nil
}
// AfterLoadFunc runs after a model's weights are loaded. It may derive new
// arrays (e.g. fused or rewritten weights) and returns them so LoadAll can
// materialize them and free the intermediate graph that produced them.
type AfterLoadFunc func(*model.Root) ([]*Array, error)

// LoadAll loads every "model*.safetensors" shard into states, applies the
// quantization metadata, runs afterLoadFuncs, then evaluates all weights and
// clears the MLX allocator cache.
func LoadAll(root *model.Root, states map[string]*Array, quantizations map[string]*Quantization, afterLoadFuncs []AfterLoadFunc) error {
	matches, err := root.Glob("model*.safetensors")
	if err != nil {
		return err
	}
	for match := range matches {
		if err := errors.Join(
			LoadWeights(root, match, states),
			LoadQuantizations(root, match, quantizations),
		); err != nil {
			return err
		}
	}
	for _, afterLoadFunc := range afterLoadFuncs {
		weights, err := afterLoadFunc(root)
		if err != nil {
			return err
		}
		for _, weight := range weights {
			// Pin the derived weight (see Load for the numRefs convention),
			// materialize it, then recursively unpin and free the graph
			// inputs that produced it so only the result stays resident.
			weight.desc.numRefs = 1000
			Eval(weight)
			var freeAll func(...*Array)
			freeAll = func(inputs ...*Array) {
				for _, input := range inputs {
					input.desc.numRefs = 0
					freeAll(input.desc.inputs...)
				}
				Free(inputs...)
			}
			freeAll(weight.desc.inputs...)
		}
	}
	// Force every loaded weight to be computed before reporting memory use.
	Eval(slices.Collect(maps.Values(states))...)
	ClearCache()
	slog.Info("Loaded weights", "count", len(states), "memory", Memory{})
	return nil
}
// UnloadAll drops the refcount of every loaded weight, frees them all, and
// logs how many bytes were returned to the system.
func UnloadAll(states map[string]*Array) {
	arrays := slices.Collect(maps.Values(states))
	for _, a := range arrays {
		a.desc.numRefs = 0
	}
	freed := Free(arrays...)
	slog.Info("Unloaded weights", "count", len(states), "num_bytes", PrettyBytes(freed), "memory", Memory{})
}

View File

@@ -1,85 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
// String renders the count as a plain integer with a "B" suffix.
func (b Byte) String() string {
	return strconv.FormatInt(int64(b), 10) + " B"
}

// String renders the count in KiB with two decimal places.
func (b KibiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}

// String renders the count in MiB with two decimal places.
func (b MebiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}

// String renders the count in GiB with two decimal places.
func (b GibiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}

// String renders the count in TiB with two decimal places.
func (b TebiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}

// PrettyBytes wraps n in the unit type that renders it most readably
// (B below 1 KiB, then KiB, MiB, GiB, and TiB for everything larger).
func PrettyBytes(n int) fmt.Stringer {
	switch {
	case n < 1<<10:
		return Byte(n)
	case n < 1<<(2*10):
		return KibiByte(n)
	case n < 1<<(3*10):
		return MebiByte(n)
	case n < 1<<(4*10):
		return GibiByte(n)
	default:
		return TebiByte(n)
	}
}
// ActiveMemory reports the bytes of MLX memory currently in use.
func ActiveMemory() int {
	var active C.size_t
	C.mlx_get_active_memory(&active)
	return int(active)
}

// CacheMemory reports the bytes held in the MLX allocator cache.
func CacheMemory() int {
	var cache C.size_t
	C.mlx_get_cache_memory(&cache)
	return int(cache)
}

// PeakMemory reports the high-water mark of MLX memory use.
func PeakMemory() int {
	var peak C.size_t
	C.mlx_get_peak_memory(&peak)
	return int(peak)
}

// Memory is a zero-size slog helper; logging it snapshots the allocator
// counters at the moment the record is emitted.
type Memory struct{}

func (Memory) LogValue() slog.Value {
	return slog.GroupValue(
		slog.Any("active", PrettyBytes(ActiveMemory())),
		slog.Any("cache", PrettyBytes(CacheMemory())),
		slog.Any("peak", PrettyBytes(PeakMemory())),
	)
}
// Integer byte counts whose String methods render them in the matching
// binary unit; PrettyBytes picks the appropriate one.
type (
	Byte     int
	KibiByte int
	MebiByte int
	GibiByte int
	TebiByte int
)

// ClearCache releases memory held in the MLX allocator cache back to the OS.
func ClearCache() {
	C.mlx_clear_cache()
}

View File

@@ -1,38 +0,0 @@
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"
// doEval materializes the given arrays, silently skipping invalid ones.
// When async is true the evaluation is scheduled without waiting for it to
// finish.
func doEval(outputs []*Array, async bool) {
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)
	for _, output := range outputs {
		if output.Valid() {
			C.mlx_vector_array_append_value(vector, output.ctx)
		}
	}
	if async {
		C.mlx_async_eval(vector)
	} else {
		C.mlx_eval(vector)
	}
}

// AsyncEval schedules evaluation of outputs without blocking.
func AsyncEval(outputs ...*Array) {
	doEval(outputs, true)
}

// Eval evaluates outputs, blocking until they are materialized.
func Eval(outputs ...*Array) {
	doEval(outputs, false)
}

View File

@@ -1,102 +0,0 @@
package mlx
import "cmp"
// Quantization holds the scale/bias tensors and the parameters describing
// how a weight was quantized. GroupSize, Bits and Mode mirror the
// safetensors "quantization" metadata parsed by Parse.
type Quantization struct {
	Scales    Array  `weight:"scales"`
	Biases    Array  `weight:"biases"`
	GroupSize int    `json:"group_size"`
	Bits      int    `json:"bits"`
	Mode      string `json:"mode"`
}

// Linear is a dense layer; when the embedded Quantization's Scales are
// valid the weight is treated as quantized.
type Linear struct {
	Weight Array `weight:"weight"`
	Bias   Array `weight:"bias"`
	Quantization
}
// Forward computes the linear transformation: x @ Weight.T + Bias.
// A quantized kernel is used when scales are present; an empty Mode falls
// back to "affine".
func (m Linear) Forward(x *Array) *Array {
	if m.Scales.Valid() {
		x = x.QuantizedMatmul(
			&m.Weight,
			&m.Scales,
			&m.Biases,
			true,
			m.GroupSize,
			m.Bits,
			cmp.Or(m.Mode, "affine"),
		)
		if m.Bias.Valid() {
			x = m.Bias.Add(x)
		}
		return x
	}
	w := m.Weight.Transpose(1, 0)
	if m.Bias.Valid() {
		// Addmm fuses the bias addition into the matmul.
		return m.Bias.Addmm(x, w, 1.0, 1.0)
	}
	return x.Matmul(w)
}
// Gather performs a gathered matrix multiplication of x against this layer's
// (possibly quantized) weight, with lhs/rhs selecting which weight slices
// each input row uses (see GatherMM/GatherQMM). An empty Mode falls back to
// "affine".
//
// The restructure removes the original's early return so the bias addition,
// previously duplicated in both branches, is written once; behavior is
// unchanged.
//
// NOTE(review): the sixth GatherQMM argument is its transpose flag, but we
// pass `sorted` there (and again as the trailing sorted flag), whereas
// Forward passes transpose=true to QuantizedMatmul. Confirm this is
// intentional and not a mixed-up argument.
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
	if m.Scales.Valid() {
		x = x.GatherQMM(
			&m.Weight,
			&m.Scales,
			&m.Biases,
			lhs,
			rhs,
			sorted,
			m.GroupSize,
			m.Bits,
			cmp.Or(m.Mode, "affine"),
			sorted,
		)
	} else {
		w := m.Weight.Transpose(0, 2, 1)
		x = x.GatherMM(w, lhs, rhs, sorted)
	}
	if m.Bias.Valid() {
		x = m.Bias.Add(x)
	}
	return x
}
// Embedding is a token-embedding table; when the embedded Quantization's
// Scales are valid the table is stored quantized.
type Embedding struct {
	Weight Array `weight:"weight"`
	Quantization
}

// Forward looks up rows of the embedding table for the given indices,
// dequantizing the gathered rows when the table is quantized.
func (e *Embedding) Forward(indices *Array) *Array {
	if e.Scales.Valid() {
		// Gather first, then dequantize only the selected rows.
		w := e.Weight.TakeAxis(indices, 0)
		return w.Dequantize(
			e.Scales.TakeAxis(indices, 0),
			e.Biases.TakeAxis(indices, 0),
			e.GroupSize,
			e.Bits,
			cmp.Or(e.Mode, "affine"),
		)
	}
	return e.Weight.TakeAxis(indices, 0)
}

// AsLinear reuses the embedding table (and its quantization) as a bias-free
// Linear layer, e.g. for tied input/output embeddings.
func (e *Embedding) AsLinear() Linear {
	return Linear{
		Weight:       e.Weight,
		Quantization: e.Quantization,
	}
}

View File

@@ -1,341 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
// Abs returns the elementwise absolute value of t.
func (t *Array) Abs() *Array {
	out := New("ABS", t)
	C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

// Add returns the elementwise sum t + other.
func (t *Array) Add(other *Array) *Array {
	out := New("ADD", t, other)
	C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// Addmm returns alpha*(a @ b) + beta*t, fusing the matmul and the addition
// (t is the accumulator; see Linear.Forward's bias fusion).
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
	out := New("ADDMM", t, a, b)
	C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
	return out
}

// Argmax returns the indices of the maxima along axis.
func (t *Array) Argmax(axis int, keepDims bool) *Array {
	out := New("ARGMAX", t)
	C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

// ArgpartitionAxis returns indices that partition t around the kth element
// along axis.
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
	out := New("ARGPARTITION", t)
	C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
	return out
}

// ArgsortAxis returns the indices that sort t along axis.
func (t *Array) ArgsortAxis(axis int) *Array {
	out := New("ARGSORT_AXIS", t)
	C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// AsType casts t to dtype.
func (t *Array) AsType(dtype DType) *Array {
	out := New("AS_TYPE", t)
	C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}

// AsStrided creates a view of t with the given shape, strides, and element
// offset (strides are in elements — see TextAttention.Forward's use).
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	cStrides := make([]C.int64_t, len(strides))
	for i, s := range strides {
		cStrides[i] = C.int64_t(s)
	}
	out := New("AS_STRIDED", t)
	C.mlx_as_strided(
		&out.ctx, t.ctx,
		unsafe.SliceData(cShape), C.size_t(len(shape)),
		unsafe.SliceData(cStrides), C.size_t(len(strides)),
		C.size_t(offset),
		DefaultStream().ctx,
	)
	return out
}

// Concatenate joins t followed by others along axis.
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)
	s := append([]*Array{t}, others...)
	for _, other := range s {
		C.mlx_vector_array_append_value(vector, other.ctx)
	}
	out := New("CONCATENATE", s...)
	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}
// Divide returns the elementwise quotient t / other.
func (t *Array) Divide(other *Array) *Array {
	out := New("DIVIDE", t, other)
	C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// Dequantize expands a quantized t back to full precision using scales and
// biases. Non-positive groupSize/bits are passed as "not set" so MLX uses
// its defaults.
func (t *Array) Dequantize(scales, biases *Array, groupSize, bits int, mode string) *Array {
	out := New("DEQUANTIZE", t, scales, biases)
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	C.mlx_dequantize(
		&out.ctx,
		t.ctx,
		scales.ctx,
		biases.ctx,
		C.mlx_optional_int{
			value:     C.int(groupSize),
			has_value: C.bool(groupSize > 0),
		},
		C.mlx_optional_int{
			value:     C.int(bits),
			has_value: C.bool(bits > 0),
		},
		cMode,
		C.mlx_optional_dtype{
			has_value: false,
		},
		DefaultStream().ctx,
	)
	return out
}

// ExpandDims inserts a size-1 dimension at axis.
func (t *Array) ExpandDims(axis int) *Array {
	out := New("EXPAND_DIMS", t)
	C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// Flatten collapses the dimensions from startAxis through endAxis into one.
func (t *Array) Flatten(startAxis, endAxis int) *Array {
	out := New("FLATTEN", t)
	C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
	return out
}

// FloorDivide returns the elementwise floored quotient of t by other.
func (t *Array) FloorDivide(other *Array) *Array {
	out := New("FLOOR_DIVIDE", t, other)
	C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// GatherMM performs a gathered matmul of t against other, with lhs/rhs
// index arrays selecting slices. Nil lhs/rhs become empty arrays, which MLX
// treats as "no gather" on that side.
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
	if lhs == nil {
		lhs = New("")
	}
	if rhs == nil {
		rhs = New("")
	}
	out := New("GATHER_MM", t, other, lhs, rhs)
	C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
	return out
}

// GatherQMM is the quantized counterpart of GatherMM. Non-positive
// groupSize/bits are passed as "not set".
func (t *Array) GatherQMM(weight, scales, biases, lhs, rhs *Array, transpose bool, groupSize, bits int, mode string, sorted bool) *Array {
	if lhs == nil {
		lhs = New("")
	}
	if rhs == nil {
		rhs = New("")
	}
	out := New("GATHER_QMM", t, weight, scales, biases, lhs, rhs)
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	C.mlx_gather_qmm(
		&out.ctx,
		t.ctx,
		weight.ctx,
		scales.ctx,
		biases.ctx,
		lhs.ctx,
		rhs.ctx,
		C.bool(transpose),
		C.mlx_optional_int{
			value:     C.int(groupSize),
			has_value: C.bool(groupSize > 0),
		},
		C.mlx_optional_int{
			value:     C.int(bits),
			has_value: C.bool(bits > 0),
		},
		cMode,
		C.bool(sorted),
		DefaultStream().ctx,
	)
	return out
}
// Logsumexp returns log(sum(exp(t))) computed in a numerically stable way.
func (t *Array) Logsumexp(keepDims bool) *Array {
	out := New("LOGSUMEXP", t)
	C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
	return out
}

// Matmul returns the matrix product t @ other.
func (t *Array) Matmul(other *Array) *Array {
	out := New("MATMUL", t, other)
	C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// Multiply returns the elementwise product t * other.
func (t *Array) Multiply(other *Array) *Array {
	out := New("MULTIPLY", t, other)
	C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// Negative returns -t.
func (t *Array) Negative() *Array {
	out := New("NEGATIVE", t)
	C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

// Power returns t raised elementwise to exponent.
func (t *Array) Power(exponent *Array) *Array {
	out := New("POWER", t, exponent)
	C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
	return out
}

// PutAlongAxis writes values into t at the positions given by indices
// along axis, returning the result.
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
	out := New("PUT_ALONG_AXIS", t, indices, values)
	C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// QuantizedMatmul multiplies t by a quantized weight described by
// scales/biases. Non-positive groupSize/bits are passed as "not set".
func (t *Array) QuantizedMatmul(weight, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
	out := New("QUANTIZED_MATMUL", t, weight, scales, biases)
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	C.mlx_quantized_matmul(
		&out.ctx,
		t.ctx,
		weight.ctx,
		scales.ctx,
		biases.ctx,
		C.bool(transpose),
		C.mlx_optional_int{
			value:     C.int(groupSize),
			has_value: C.bool(groupSize > 0),
		},
		C.mlx_optional_int{
			value:     C.int(bits),
			has_value: C.bool(bits > 0),
		},
		cMode,
		DefaultStream().ctx,
	)
	return out
}

// Reshape returns t reshaped to the given dimensions.
func (t *Array) Reshape(axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}
	out := New("RESHAPE", t)
	C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

// Sigmoid returns the elementwise logistic sigmoid of t.
func (t *Array) Sigmoid() *Array {
	out := New("SIGMOID", t)
	C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

// Sqrt returns the elementwise square root of t.
func (t *Array) Sqrt() *Array {
	out := New("SQRT", t)
	C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

// Squeeze removes the size-1 dimension at axis.
func (t *Array) Squeeze(axis int) *Array {
	out := New("SQUEEZE", t)
	C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// StackAxis stacks t followed by others along a new axis.
// NOTE(review): the graph inputs are recorded as (others..., t) while the C
// call stacks (t, others...) — inconsistent with Concatenate; confirm the
// input order only matters for refcount bookkeeping.
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
	vectorData := make([]C.mlx_array, len(others)+1)
	vectorData[0] = t.ctx
	for i := range others {
		vectorData[i+1] = others[i].ctx
	}
	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)
	out := New("STACK_AXIS", append(others, t)...)
	C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}
// Subtract returns the elementwise difference t - other.
func (t *Array) Subtract(other *Array) *Array {
	out := New("SUBTRACT", t, other)
	C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

// SumAxis sums t along axis.
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
	out := New("SUM_AXIS", t)
	C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

// TakeAxis gathers elements of t at indices along axis.
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
	out := New("TAKE_AXIS", t, indices)
	C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// TakeAlongAxis gathers elements of t matching indices elementwise along
// axis (indices broadcast against t).
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
	out := New("TAKE_ALONG_AXIS", t, indices)
	C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// Tanh returns the elementwise hyperbolic tangent of t.
func (t *Array) Tanh() *Array {
	out := New("TANH", t)
	C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

// Transpose permutes the dimensions of t in the given order.
func (t *Array) Transpose(axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i, axis := range axes {
		cAxes[i] = C.int(axis)
	}
	out := New("TRANSPOSE", t)
	C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

// Zeros returns a zero-filled array of the given dtype and shape.
func Zeros(dtype DType, shape ...int) *Array {
	cAxes := make([]C.int, len(shape))
	for i := range shape {
		cAxes[i] = C.int(shape[i])
	}
	t := New("ZEROS")
	C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
	return t
}

View File

@@ -1,11 +0,0 @@
package mlx
// #include "generated.h"
import "C"
// Categorical samples category indices from the (unnormalized log-)scores
// in t along axis. The key passed to MLX is a freshly created empty array —
// NOTE(review): confirm mlx_random_categorical treats an empty key as "use
// the global RNG state"; otherwise sampling may be deterministic.
func (t *Array) Categorical(axis int) *Array {
	key := New("")
	out := New("", t, key)
	C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
	return out
}

View File

@@ -1,84 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"cmp"
"unsafe"
)
// slice describes one dimension of a Python-style slice expression.
// args holds 0–3 values: [] = whole dimension, [i] = single index,
// [i, j] = i:j, [i, j, k] = i:j:k.
type slice struct {
	args []int
}

// Slice builds a per-dimension slice argument for Array.Slice and
// Array.SliceUpdate.
func Slice(args ...int) slice {
	return slice{args: args}
}

// makeSlices converts per-dimension slice specs into the parallel
// start/stop/stride C arrays expected by mlx_slice. In the two- and
// three-arg forms a stop of 0 means "to the end of the dimension" (via
// cmp.Or), so an explicit stop of 0 cannot be expressed.
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
	if len(slices) != len(dims) {
		panic("number of slice arguments must match number of tensor dimensions")
	}
	// args[0]=starts, args[1]=stops, args[2]=strides.
	args := [3][]C.int{
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
	}
	for i, s := range slices {
		switch len(s.args) {
		case 0:
			// slice[:]
			args[0][i] = C.int(0)
			args[1][i] = C.int(dims[i])
			args[2][i] = C.int(1)
		case 1:
			// slice[i]
			args[0][i] = C.int(s.args[0])
			args[1][i] = C.int(s.args[0] + 1)
			args[2][i] = C.int(1)
		case 2:
			// slice[i:j]
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(1)
		case 3:
			// slice[i:j:k]
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(s.args[2])
		default:
			panic("invalid slice arguments")
		}
	}
	return args[0], args[1], args[2]
}
// Slice returns the sub-array of t selected by one slice spec per
// dimension (see Slice and makeSlices for the spec semantics).
func (t *Array) Slice(slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE", t)
	C.mlx_slice(
		&out.ctx, t.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}

// SliceUpdate returns a copy of t with the region selected by slices
// replaced by other.
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE_UPDATE", t, other)
	C.mlx_slice_update(
		&out.ctx, t.ctx, other.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}

View File

@@ -1,43 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"log/slog"
"sync"
)
// Device wraps an MLX device handle.
type Device struct {
	ctx C.mlx_device
}

// LogValue renders the device's MLX string form for slog.
func (d Device) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_device_tostring(&str, d.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

// DefaultDevice returns MLX's default device, resolved once and cached.
var DefaultDevice = sync.OnceValue(func() Device {
	d := C.mlx_device_new()
	C.mlx_get_default_device(&d)
	return Device{d}
})
// Stream wraps an MLX stream handle.
type Stream struct {
	ctx C.mlx_stream
}

// LogValue renders the stream's MLX string form for slog.
func (s Stream) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_stream_tostring(&str, s.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

// DefaultStream returns the default stream of the default device, resolved
// once and cached; every op wrapper in this package submits to it.
var DefaultStream = sync.OnceValue(func() Stream {
	s := C.mlx_stream_new()
	C.mlx_get_default_stream(&s, DefaultDevice().ctx)
	return Stream{s}
})

View File

@@ -1,8 +0,0 @@
package base
import "github.com/ollama/ollama/x/mlxrunner/cache"
// Cacher is implemented by models that support custom caching mechanisms.
type Cacher interface {
	// Cache builds the model's KV caches, typically one per layer.
	Cache() []cache.Cache
}

View File

@@ -1,116 +0,0 @@
package base
import (
"encoding/json"
"errors"
"log/slog"
"reflect"
"strconv"
"strings"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// Model is the minimal interface a runnable model implements.
type Model interface {
	// Forward performs a forward pass through the model.
	Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
	// NumLayers returns the number of layers in the model.
	// This is used to initialize caches.
	// TODO: consider moving cache initialization into the model itself.
	NumLayers() int
}

// TextGeneration is a Model that can also project hidden states back to
// vocabulary logits.
type TextGeneration interface {
	Model
	Unembed(*mlx.Array) *mlx.Array
}
// Walk reflects over the model's struct tree and collects, keyed by the
// dotted path of `weight` struct tags (with slice indices appended): every
// mlx.Array field, every mlx.Quantization field, and every AfterLoad method
// found along the way. The maps point at the live fields so loaders can
// fill them in place.
func Walk(m Model) (map[string]*mlx.Array, map[string]*mlx.Quantization, []mlx.AfterLoadFunc) {
	weights := make(map[string]*mlx.Array)
	quantizations := make(map[string]*mlx.Quantization)
	var afterLoadFuncs []mlx.AfterLoadFunc
	var fn func(v reflect.Value, tags []string)
	fn = func(v reflect.Value, tags []string) {
		t := v.Type()
		// Bind any AfterLoad method as an mlx.AfterLoadFunc via reflection.
		if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
			var afterLoadFunc mlx.AfterLoadFunc
			reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
			afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
		}
		if t == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
			name := strings.Join(tags, ".")
			weights[name] = v.Addr().Interface().(*mlx.Array)
			return
		} else if t == reflect.TypeOf((*mlx.Quantization)(nil)).Elem() {
			// Quantization structs are recorded and then also recursed into,
			// so their Scales/Biases arrays are collected as weights too.
			quantizations[strings.Join(tags, ".")] = v.Addr().Interface().(*mlx.Quantization)
		}
		for _, field := range reflect.VisibleFields(t) {
			if field.IsExported() {
				tt, vv := field.Type, v.FieldByIndex(field.Index)
				// create local copy so tags are not modified between fields
				tags := tags
				if tag := field.Tag.Get("weight"); tag != "" {
					// TODO: use model.Tag
					tags = append(tags, tag)
				}
				switch tt.Kind() {
				case reflect.Interface:
					vv = vv.Elem()
					fallthrough
				case reflect.Pointer:
					vv = vv.Elem()
					fallthrough
				case reflect.Struct:
					fn(vv, tags)
				case reflect.Slice, reflect.Array:
					for i := range vv.Len() {
						fn(vv.Index(i), append(tags, strconv.Itoa(i)))
					}
				}
			}
		}
	}
	fn(reflect.ValueOf(m).Elem(), []string{})
	return weights, quantizations, afterLoadFuncs
}
// m maps Hugging Face architecture names to model constructors.
// Registration happens from package init functions, so no locking is done.
var m = make(map[string]func(*model.Root) (Model, error))

// Register adds a constructor for the named architecture; it panics on a
// duplicate registration.
func Register(name string, f func(*model.Root) (Model, error)) {
	if _, exists := m[name]; exists {
		panic("model already registered: " + name)
	}
	m[name] = f
}
// New reads the model's config.json and constructs the registered model for
// its first listed architecture. It returns an error when the config cannot
// be read, lists no architectures, or names an unregistered architecture.
func New(root *model.Root) (Model, error) {
	c, err := root.Open("config.json")
	if err != nil {
		return nil, err
	}
	defer c.Close()
	var config struct {
		Architectures []string `json:"architectures"`
	}
	if err := json.NewDecoder(c).Decode(&config); err != nil {
		return nil, err
	}
	// Guard the index: the original panicked on a config with an empty or
	// missing "architectures" list.
	if len(config.Architectures) == 0 {
		return nil, errors.New("config.json lists no architectures")
	}
	arch := config.Architectures[0]
	slog.Info("Model architecture", "arch", arch)
	if f, exists := m[arch]; exists {
		return f(root)
	}
	// Include the architecture name so failures are diagnosable.
	return nil, errors.New("unknown architecture: " + arch)
}

View File

@@ -1,84 +0,0 @@
package gemma
import (
"cmp"
"encoding/json"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// Model is the Gemma3 conditional-generation model; only the weights under
// the "language_model" prefix are mapped here.
type Model struct {
	Text TextModel `weight:"language_model"`
}

// NumLayers returns the number of transformer layers.
func (m *Model) NumLayers() int {
	return len(m.Text.Layers)
}

// Cache builds one cache per layer: every SlidingWindowPattern-th layer gets
// a full KV cache, the rest a rotating cache of SlidingWindow positions.
func (m Model) Cache() []cache.Cache {
	caches := make([]cache.Cache, m.NumLayers())
	for i := range caches {
		if (i+1)%m.Text.Options.SlidingWindowPattern == 0 {
			caches[i] = cache.NewKVCache()
		} else {
			caches[i] = cache.NewRotatingKVCache(m.Text.Options.SlidingWindow)
		}
	}
	return caches
}

// Forward runs a forward pass through the text model.
func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
	return m.Text.Forward(inputs, cache)
}

// Unembed projects hidden states to vocabulary logits using the tied
// embedding table.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.Text.EmbedTokens.AsLinear().Forward(x)
}
// init registers the Gemma3 constructor: it decodes text_config from
// config.json, applies Gemma3 defaults for fields the config may omit, and
// pre-builds the two RoPE variants (global vs. sliding-window layers).
func init() {
	base.Register("Gemma3ForConditionalGeneration", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}
		var opts struct {
			Text TextOptions `json:"text_config"`
		}
		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}
		// Defaults for fields absent from smaller checkpoints' configs.
		opts.Text.NumAttentionHeads = cmp.Or(opts.Text.NumAttentionHeads, 8)
		opts.Text.NumKeyValueHeads = cmp.Or(opts.Text.NumKeyValueHeads, 4)
		opts.Text.HeadDim = cmp.Or(opts.Text.HeadDim, 256)
		opts.Text.RMSNormEps = cmp.Or(opts.Text.RMSNormEps, 1e-6)
		opts.Text.SlidingWindowPattern = cmp.Or(opts.Text.SlidingWindowPattern, 6)
		// TODO: implement json.Unmarshaler
		// Keyed by "is a global-attention layer": true = long-context RoPE
		// (base 1e6, scaled 1/8), false = local RoPE (base 1e4).
		opts.Text.RoPE = map[bool]mlx.RoPE{
			true:  {Dims: opts.Text.HeadDim, Traditional: false, Base: 1_000_000, Scale: 1. / 8.},
			false: {Dims: opts.Text.HeadDim, Traditional: false, Base: 10_000, Scale: 1},
		}
		return &Model{
			Text: TextModel{
				Layers:  make([]TextDecoderLayer, opts.Text.NumHiddenLayers),
				Options: opts.Text,
			},
		}, nil
	})
}
// RMSNorm wraps mlx.RMSNorm for Gemma, whose checkpoints appear to store
// the norm weight offset by -1 (the "(1 + w) * norm(x)" formulation).
type RMSNorm struct {
	mlx.RMSNorm
}

// AfterLoad adds 1 to the loaded weight in place so Forward can use it
// directly; it derives no new arrays.
func (m *RMSNorm) AfterLoad(*model.Root) ([]*mlx.Array, error) {
	m.Weight.Set(m.Weight.Add(mlx.FromValue(1)))
	return []*mlx.Array{}, nil
}

View File

@@ -1,118 +0,0 @@
package gemma
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// TextOptions holds the Gemma3 text-model hyperparameters decoded from
// config.json's text_config, plus the two RoPE variants built in init
// (keyed by whether a layer uses global attention).
type TextOptions struct {
	HiddenSize           int     `json:"hidden_size"`
	NumHiddenLayers      int     `json:"num_hidden_layers"`
	IntermediateSize     int     `json:"intermediate_size"`
	NumAttentionHeads    int     `json:"num_attention_heads"`
	NumKeyValueHeads     int     `json:"num_key_value_heads"`
	HeadDim              int     `json:"head_dim"`
	RMSNormEps           float32 `json:"rms_norm_eps"`
	SlidingWindow        int     `json:"sliding_window"`
	SlidingWindowPattern int     `json:"sliding_window_pattern"`
	RoPE                 map[bool]mlx.RoPE
}

// TextModel is the Gemma3 decoder stack.
type TextModel struct {
	EmbedTokens mlx.Embedding      `weight:"model.embed_tokens"`
	Layers      []TextDecoderLayer `weight:"model.layers"`
	Norm        RMSNorm            `weight:"model.norm"`
	Options     TextOptions
}
// Forward embeds the input token IDs (shape B x L), scales embeddings by
// sqrt(hidden_size), runs every decoder layer with its cache and RoPE
// variant, and applies the final norm.
func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	hiddenStates := m.EmbedTokens.Forward(inputs)
	// Gemma scales token embeddings by sqrt(hidden_size).
	hiddenSize := mlx.FromValue(m.Options.HiddenSize).AsType(hiddenStates.DType())
	hiddenStates = hiddenStates.Multiply(hiddenSize.Sqrt())
	for i, layer := range m.Layers {
		// Every SlidingWindowPattern-th layer uses the global RoPE variant.
		hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options.RoPE[(i+1)%m.Options.SlidingWindowPattern == 0], m.Options)
	}
	hiddenStates = m.Norm.Forward(hiddenStates, m.Options.RMSNormEps)
	return hiddenStates
}
// TextDecoderLayer is one Gemma3 decoder block: sandwich-normed attention
// followed by sandwich-normed MLP.
type TextDecoderLayer struct {
	InputNorm    RMSNorm       `weight:"input_layernorm"`
	Attention    TextAttention `weight:"self_attn"`
	PostAttnNorm RMSNorm       `weight:"post_attention_layernorm"`
	PreFFNorm    RMSNorm       `weight:"pre_feedforward_layernorm"`
	MLP          TextMLP       `weight:"mlp"`
	PostFFNorm   RMSNorm       `weight:"post_feedforward_layernorm"`
}

// Forward applies the block; note the residual is added after the
// post-norms (Gemma's sandwich normalization), not before.
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	residual := hiddenStates
	hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
	hiddenStates = m.PostAttnNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)
	residual = hiddenStates
	hiddenStates = m.PreFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.MLP.Forward(hiddenStates, opts)
	hiddenStates = m.PostFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)
	return hiddenStates
}
// TextAttention is Gemma3 grouped-query attention with per-head Q/K norms.
type TextAttention struct {
	QProj mlx.Linear `weight:"q_proj"`
	QNorm RMSNorm    `weight:"q_norm"`
	KProj mlx.Linear `weight:"k_proj"`
	KNorm RMSNorm    `weight:"k_norm"`
	VProj mlx.Linear `weight:"v_proj"`
	OProj mlx.Linear `weight:"o_proj"`
}

// Forward computes attention for hiddenStates of shape (B, L, hidden).
// The AsStrided calls reinterpret the flat (B, L, heads*headDim)
// projections directly as (B, heads, L, headDim) without a separate
// reshape+transpose.
func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	query := m.QProj.Forward(hiddenStates)
	key := m.KProj.Forward(hiddenStates)
	value := m.VProj.Forward(hiddenStates)
	query = query.AsStrided(
		[]int{B, opts.NumAttentionHeads, L, opts.HeadDim},
		[]int{L * opts.NumAttentionHeads * opts.HeadDim, opts.HeadDim, opts.NumAttentionHeads * opts.HeadDim, 1},
		0)
	key = key.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)
	value = value.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)
	// Per-head RMS norms before rotary embedding, both offset by the cache
	// position.
	query = m.QNorm.Forward(query, opts.RMSNormEps)
	key = m.KNorm.Forward(key, opts.RMSNormEps)
	query = rope.Forward(query, cache.Offset())
	key = rope.Forward(key, cache.Offset())
	key, value = cache.Update(key, value)
	attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(opts.HeadDim))))
	// Back to (B, L, heads*headDim) for the output projection.
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}
// TextMLP is the Gemma3 gated feed-forward block.
type TextMLP struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

// Forward computes down(gelu_approx(gate(h)) * up(h)).
func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
	return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}

View File

@@ -1,334 +0,0 @@
package glm
import (
"encoding/json"
"math"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// Options holds the GLM hyperparameters from config.json, covering the
// low-rank (MLA-style) attention dimensions and the MoE routing settings,
// plus the embedded RoPE configuration.
type Options struct {
	HiddenSize          int     `json:"hidden_size"`
	NumHiddenLayers     int     `json:"num_hidden_layers"`
	IntermediateSize    int     `json:"intermediate_size"`
	NumAttentionHeads   int     `json:"num_attention_heads"`
	NumKeyValueHeads    int     `json:"num_key_value_heads"`
	RMSNormEps          float32 `json:"rms_norm_eps"`
	RopeTheta           float32 `json:"rope_theta"`
	QLoraRank           int     `json:"q_lora_rank"`
	KVLoraRank          int     `json:"kv_lora_rank"`
	QKRopeHeadDim       int     `json:"qk_rope_head_dim"`
	QKNopeHeadDim       int     `json:"qk_nope_head_dim"`
	NumRoutedExperts    int     `json:"n_routed_experts"`
	NumSharedExperts    int     `json:"n_shared_experts"`
	NumExpertsPerTok    int     `json:"num_experts_per_tok"`
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"`
	FirstKDenseReplace  int     `json:"first_k_dense_replace"`
	mlx.RoPE
}

// Model is the GLM decoder stack with a separate (untied) LM head.
type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	LMHead      mlx.Linear    `weight:"lm_head"`
	Options
}

// NumLayers returns the number of transformer layers.
func (m Model) NumLayers() int {
	return len(m.Layers)
}
// Forward embeds the input token IDs (shape B x L), runs every layer with
// its cache, and applies the final norm.
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	h := m.EmbedTokens.Forward(inputs)
	for i, layer := range m.Layers {
		h = layer.Forward(h, caches[i], B, L, m.Options)
	}
	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}

// Unembed projects hidden states to vocabulary logits via the LM head.
func (m Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}
// Layer is one GLM decoder block: pre-normed attention and pre-normed MLP,
// each with a residual connection.
type Layer struct {
	InputLayernorm         mlx.RMSNorm `weight:"input_layernorm"`
	Attention              Attention   `weight:"self_attn"`
	PostAttentionLayernorm mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP                    MLP         `weight:"mlp"`
}

// Forward applies the block with standard pre-norm residuals.
func (m Layer) Forward(h *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	r := h
	h = m.InputLayernorm.Forward(h, opts.RMSNormEps)
	h = m.Attention.Forward(h, cache, B, L, opts)
	h = h.Add(r)
	r = h
	h = m.PostAttentionLayernorm.Forward(h, opts.RMSNormEps)
	h = m.MLP.Forward(h, B, L, opts)
	h = h.Add(r)
	return h
}

// MultiLinear is a batched linear layer: Weight has a leading head/expert
// dimension and Forward multiplies per slice.
type MultiLinear struct {
	Weight mlx.Array `weight:"weight"`
}

// Forward computes x @ Weight.T batched over the leading dimension.
func (m MultiLinear) Forward(x *mlx.Array) *mlx.Array {
	return x.Matmul(m.Weight.Transpose(0, 2, 1))
}
// Attention implements GLM's low-rank (multi-head latent) attention: Q and
// KV each pass through a down-projection ("a") and an up-projection ("b");
// embedQ and unembedOut are derived from KVBProj after load.
type Attention struct {
	QAProj         mlx.Linear  `weight:"q_a_proj"`
	QALayernorm    mlx.RMSNorm `weight:"q_a_layernorm"`
	QBProj         mlx.Linear  `weight:"q_b_proj"`
	KVAProjWithMQA mlx.Linear  `weight:"kv_a_proj_with_mqa"`
	KVALayernorm   mlx.RMSNorm `weight:"kv_a_layernorm"`
	KVBProj        mlx.Linear  `weight:"kv_b_proj"`
	// Derived in AfterLoad by splitting KVBProj per head.
	embedQ     MultiLinear
	unembedOut MultiLinear
	OProj      mlx.Linear `weight:"o_proj"`
}

// AfterLoad splits the loaded kv_b_proj weight into the per-head key
// ("embedQ", first QKNopeHeadDim rows, transposed) and value ("unembedOut",
// remaining rows) projections used in Forward, and returns them so the
// loader materializes them. Note mlx.Slice(n, 0) means "from n to the end"
// (see makeSlices).
func (m *Attention) AfterLoad(root *model.Root) ([]*mlx.Array, error) {
	bts, err := root.ReadFile("config.json")
	if err != nil {
		return nil, err
	}
	var opts struct {
		NumAttentionHeads int `json:"num_attention_heads"`
		QKNopeHeadDim     int `json:"qk_nope_head_dim"`
		KVLoraRank        int `json:"kv_lora_rank"`
	}
	if err := json.Unmarshal(bts, &opts); err != nil {
		return nil, err
	}
	w := m.KVBProj.Weight.Reshape(opts.NumAttentionHeads, -1, opts.KVLoraRank)
	m.embedQ.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim), mlx.Slice()).Transpose(0, 2, 1))
	m.unembedOut.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0), mlx.Slice()))
	return []*mlx.Array{
		&m.embedQ.Weight,
		&m.unembedOut.Weight,
	}, nil
}
// Forward computes multi-head latent attention. Queries go through the
// low-rank a/b projections; keys and values share one compressed
// representation, of which only the RoPE slice is position-encoded. The
// cache stores the compressed key; the value is re-derived from it each
// step, so attention happens in the compressed (latent) space.
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	query := m.QAProj.Forward(hiddenStates)
	query = m.QALayernorm.Forward(query, opts.RMSNormEps)
	query = m.QBProj.Forward(query)
	query = query.Reshape(B, L, opts.NumAttentionHeads, -1)
	query = query.Transpose(0, 2, 1, 3)
	// Split per-head query into the non-positional and RoPE parts.
	// Recall mlx.Slice(n, 0) means "from n to the end" (see makeSlices).
	queryNope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim))
	queryRope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0))
	compressedKV := m.KVAProjWithMQA.Forward(hiddenStates)
	keyRope := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(opts.KVLoraRank, 0))
	keyRope = keyRope.Reshape(B, L, 1, opts.QKRopeHeadDim)
	keyRope = keyRope.Transpose(0, 2, 1, 3)
	kvCompressed := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
	var offset int
	if cache != nil {
		offset = cache.Offset()
	}
	queryRope = opts.RoPE.Forward(queryRope, offset)
	keyRope = opts.RoPE.Forward(keyRope, offset)
	key := m.KVALayernorm.Forward(kvCompressed, opts.RMSNormEps).
		ExpandDims(1).
		Concatenate(3, keyRope)
	if cache != nil {
		// Only the key is cached; a zero-width value keeps the cache's
		// shape bookkeeping happy. NOTE(review): this relies on the cache
		// implementation tolerating a last-dimension size of 0.
		key, _ = cache.Update(key, mlx.Zeros(mlx.DTypeBFloat16, B, 1, L, 0))
	}
	// The value is the compressed (non-RoPE) prefix of the cached key.
	value := key.Clone().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
	// Project the non-positional query into the latent space, then attend.
	query = m.embedQ.Forward(queryNope).Concatenate(3, queryRope)
	attention := mlx.ScaledDotProductAttention(query, key, value, nil, float32(1.0/math.Sqrt(float64(opts.QKNopeHeadDim+opts.QKRopeHeadDim))))
	// Expand latent attention output back to per-head dimensions.
	attention = m.unembedOut.Forward(attention)
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}
// MLP abstracts over the dense and sparse (MoE) feed-forward variants.
type MLP interface {
	Forward(*mlx.Array, int, int, Options) *mlx.Array
}

// dense is a standard gated feed-forward block.
type dense struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

// Forward computes down(silu(gate(h)) * up(h)); B and L are unused here.
func (m dense) Forward(h *mlx.Array, _, _ int, opts Options) *mlx.Array {
	h = mlx.SILU(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h))
	return m.DownProj.Forward(h)
}

// Gate is the MoE router with its expert-score correction bias.
type Gate struct {
	Gate           mlx.Linear `weight:"gate"`
	CorrectionBias mlx.Array  `weight:"gate.e_score_correction_bias"`
}
// expertSelect caches the compiled expert-selection closure.
// NOTE(review): the lazy init below is not goroutine-safe, and the closure
// is compiled with the opts of the first call only — fine while all layers
// share one Options, but worth a sync.Once if that ever changes.
var expertSelect *mlx.Closure

// ExpertSelect returns a compiled closure that picks the top
// NumExpertsPerTok experts per token: sigmoid scores are biased by the
// correction bias for *selection*, while the returned weights come from the
// unbiased scores, optionally normalized and scaled by RoutedScalingFactor.
func ExpertSelect(opts Options) *mlx.Closure {
	if expertSelect == nil {
		expertSelect = mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
			scores, correctionBias := inputs[0], inputs[1]
			scores = scores.Sigmoid()
			original := scores
			scores = scores.Add(correctionBias)
			// Partial sort: indices of the NumExpertsPerTok largest scores.
			indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
			indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
			// Gather the unbiased scores for the selected experts.
			scores = original.TakeAlongAxis(indices, -1)
			if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
				scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
			}
			scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
			return []*mlx.Array{indices, scores}
		}, false)
	}
	return expertSelect
}
// Forward routes hidden states h through the gate, returning the selected
// expert indices and their (scaled) routing scores.
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
	logits := m.Gate.Forward(h).AsType(mlx.DTypeFloat32)
	out := ExpertSelect(opts).Call([]*mlx.Array{logits, &m.CorrectionBias})
	indices, scores = out[0], out[1]
	return indices, scores
}
// sparse is the mixture-of-experts feed-forward block: a router (embedded
// Gate), per-expert dense MLPs, and an always-on shared-experts branch.
type sparse struct {
	Gate
	Experts []dense `weight:"experts"`
	// fused holds the per-expert weights stacked into single tensors by
	// AfterLoad so expert computation can use gathered matmuls.
	fused struct {
		GateProj mlx.Linear
		UpProj   mlx.Linear
		DownProj mlx.Linear
	}
	SharedExperts dense `weight:"shared_experts"`
}
// AfterLoad stacks the per-expert projection weights into single fused
// tensors for gathered matmuls, returning the fused arrays so the loader can
// track and evaluate them.
func (m *sparse) AfterLoad(*model.Root) ([]*mlx.Array, error) {
	n := len(m.Experts)
	gates := make([]*mlx.Array, n)
	ups := make([]*mlx.Array, n)
	downs := make([]*mlx.Array, n)
	for i := range m.Experts {
		e := &m.Experts[i]
		gates[i] = &e.GateProj.Weight
		ups[i] = &e.UpProj.Weight
		downs[i] = &e.DownProj.Weight
	}
	m.fused.GateProj.Weight.Set(gates[0].StackAxis(0, gates[1:]...))
	m.fused.UpProj.Weight.Set(ups[0].StackAxis(0, ups[1:]...))
	m.fused.DownProj.Weight.Set(downs[0].StackAxis(0, downs[1:]...))
	return []*mlx.Array{
		&m.fused.GateProj.Weight,
		&m.fused.UpProj.Weight,
		&m.fused.DownProj.Weight,
	}, nil
}
// Forward computes the MoE feed-forward pass: route each token to its top-k
// experts, run the fused expert projections, combine expert outputs by the
// routing scores, and add the shared-experts branch.
func (m sparse) Forward(h *mlx.Array, B, L int, opts Options) *mlx.Array {
	indices, scores := m.Gate.Forward(h, opts)
	scores = scores.ExpandDims(-1)
	// Flatten tokens to (tokens, 1, 1, hidden) for the gathered matmuls.
	flat := h.ExpandDims(-2).ExpandDims(-2).Reshape(-1, 1, 1, opts.HiddenSize)
	indices = indices.Reshape(-1, opts.NumExpertsPerTok)
	// For larger batches, sorting (token, expert) pairs by expert id makes
	// the gathers more coherent; inverseOrder undoes the permutation after.
	sort := B*L >= 64
	var inverseOrder *mlx.Array
	if sort {
		indicesAll := indices.Flatten(0, len(indices.Dims())-1)
		order := indicesAll.ArgsortAxis(0)
		inverseOrder = order.ArgsortAxis(0)
		// FloorDivide maps each sorted slot back to its source token row.
		flat = flat.Squeeze(1).TakeAxis(order.FloorDivide(mlx.FromValue(opts.NumExpertsPerTok)), 0).ExpandDims(1)
		indices = indicesAll.TakeAxis(order, 0).Reshape(B*L*opts.NumExpertsPerTok, 1)
	}
	// SwiGLU over the fused expert weights.
	experts := mlx.SILU(m.fused.GateProj.Gather(flat, nil, indices, sort)).
		Multiply(m.fused.UpProj.Gather(flat, nil, indices, sort))
	experts = m.fused.DownProj.Gather(experts, nil, indices, sort)
	if sort {
		experts = experts.Squeeze(2).Squeeze(1).TakeAxis(inverseOrder, 0)
		experts = experts.Reshape(-1, opts.NumExpertsPerTok, opts.HiddenSize)
	} else {
		experts = experts.Squeeze(2)
	}
	experts = experts.Reshape(B, L, opts.NumExpertsPerTok, opts.HiddenSize)
	// Score-weighted sum over the selected experts, then add shared experts.
	experts = experts.Multiply(scores).SumAxis(-2, false).AsType(experts.DType())
	experts = experts.Add(m.SharedExperts.Forward(h, B, L, opts))
	return experts.Reshape(B, L, -1)
}
// init registers the GLM-4 MoE Lite architecture with the model factory.
func init() {
	base.Register("Glm4MoeLiteForCausalLM", func(root *model.Root) (base.Model, error) {
		raw, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}
		var opts Options
		if err := json.Unmarshal(raw, &opts); err != nil {
			return nil, err
		}
		opts.RoPE = mlx.RoPE{
			Dims:        opts.QKRopeHeadDim,
			Traditional: true,
			Base:        opts.RopeTheta,
			Scale:       1,
		}
		// The first FirstKDenseReplace layers use a dense MLP; the rest are MoE.
		layers := make([]Layer, opts.NumHiddenLayers)
		for i := range layers {
			if i < opts.FirstKDenseReplace {
				layers[i].MLP = &dense{}
				continue
			}
			layers[i].MLP = &sparse{Experts: make([]dense, opts.NumRoutedExperts)}
		}
		return &Model{Layers: layers, Options: opts}, nil
	})
}

View File

@@ -1,130 +0,0 @@
package llama
import (
"encoding/json"
"log/slog"
"math"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// Options mirrors the fields of the model's config.json that this
// architecture needs, plus the RoPE configuration embedded for convenience.
type Options struct {
	HiddenAct         string  `json:"hidden_act"`
	HiddenSize        int     `json:"hidden_size"`
	IntermediateSize  int     `json:"intermediate_size"`
	NumAttentionHeads int     `json:"num_attention_heads"`
	NumHiddenLayers   int     `json:"num_hidden_layers"`
	NumKeyValueHeads  int     `json:"num_key_value_heads"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	mlx.RoPE
}
// Model is a Llama-style decoder-only transformer: token embedding, a stack
// of layers, a final RMS norm, and an output (unembedding) projection.
type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	Output      mlx.Linear    `weight:"lm_head"`
	Options
}
// NumLayers reports the number of transformer layers (used to size KV caches).
func (m Model) NumLayers() int {
	return len(m.Layers)
}
// Forward embeds the token ids, runs them through every transformer layer,
// applies the final RMS norm, and projects to vocabulary logits.
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
	B, L := inputs.Dim(0), inputs.Dim(1)
	h := m.EmbedTokens.Forward(inputs)
	for i := range m.Layers {
		h = m.Layers[i].Forward(h, caches[i], B, L, m.Options)
	}
	h = m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
	slog.Debug("Model.forward", "output shape", h.Dims(), "m.Output", m.Output.Weight.Dims())
	return h
}
// Layer is one pre-norm transformer block: RMS norm + attention with a
// residual connection, then RMS norm + MLP with a residual connection.
type Layer struct {
	AttentionNorm mlx.RMSNorm `weight:"input_layernorm"`
	Attention     Attention   `weight:"self_attn"`
	MLPNorm       mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP           MLP         `weight:"mlp"`
}
// Forward applies the attention sub-block then the feed-forward sub-block,
// each with pre-normalization and a residual connection.
func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
	// Attention: h = x + Attn(Norm(x))
	h := m.Attention.Forward(m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps), c, B, L, opts)
	h = h.Add(hiddenStates)
	// Feed-forward: out = h + MLP(Norm(h))
	out := m.MLP.Forward(m.MLPNorm.Forward(h, opts.RMSNormEps))
	return out.Add(h)
}
// Attention holds the query/key/value/output projections for grouped-query
// self-attention.
type Attention struct {
	QueryProj  mlx.Linear `weight:"q_proj"`
	KeyProj    mlx.Linear `weight:"k_proj"`
	ValueProj  mlx.Linear `weight:"v_proj"`
	OutputProj mlx.Linear `weight:"o_proj"`
}
// Forward computes cached self-attention: project Q/K/V, split into heads,
// apply rotary position embeddings at the cache offset, update the KV cache,
// and run scaled dot-product attention before the output projection.
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	// heads reshapes a projection to (B, numHeads, L, headDim).
	heads := func(x *mlx.Array, n int) *mlx.Array {
		return x.Reshape(B, L, n, -1).Transpose(0, 2, 1, 3)
	}
	query := heads(m.QueryProj.Forward(hiddenStates), opts.NumAttentionHeads)
	key := heads(m.KeyProj.Forward(hiddenStates), opts.NumKeyValueHeads)
	value := heads(m.ValueProj.Forward(hiddenStates), opts.NumKeyValueHeads)
	offset := cache.Offset()
	query = opts.RoPE.Forward(query, offset)
	key = opts.RoPE.Forward(key, offset)
	key, value = cache.Update(key, value)
	scale := 1.0 / float32(math.Sqrt(float64(key.Dim(-1))))
	attention := mlx.ScaledDotProductAttention(query, key, value, nil, scale)
	return m.OutputProj.Forward(attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1))
}
// MLP is the SwiGLU feed-forward block: gate and up projections combined
// through SiLU, then projected back down to the hidden size.
type MLP struct {
	Gate mlx.Linear `weight:"gate_proj"`
	Up   mlx.Linear `weight:"up_proj"`
	Down mlx.Linear `weight:"down_proj"`
}
// Forward applies the SwiGLU transform down(silu(gate(h)) * up(h)).
func (m MLP) Forward(h *mlx.Array) *mlx.Array {
	gated := mlx.SILU(m.Gate.Forward(h))
	return m.Down.Forward(gated.Multiply(m.Up.Forward(h)))
}
// init registers the Mistral architecture with the model factory.
func init() {
	base.Register("MistralForCausalLM", func(root *model.Root) (base.Model, error) {
		raw, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}
		// TODO: implement json.Unmarshal for Options
		var opts Options
		if err := json.Unmarshal(raw, &opts); err != nil {
			return nil, err
		}
		// RoPE settings live at the top level of config.json as well, so the
		// embedded struct is unmarshaled separately.
		if err := json.Unmarshal(raw, &opts.RoPE); err != nil {
			return nil, err
		}
		return &Model{
			Layers:  make([]Layer, opts.NumHiddenLayers),
			Options: opts,
		}, nil
	})
}

View File

@@ -1,7 +0,0 @@
package model
import (
_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
_ "github.com/ollama/ollama/x/mlxrunner/model/glm/4/moe/lite"
_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
)

View File

@@ -1,138 +0,0 @@
package mlxrunner
import (
"bytes"
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// TextGenerationPipeline services one completion request end to end: it
// tokenizes the prompt, reuses/extends a prefix KV cache, prefills the prompt
// in chunks, then decodes tokens one at a time, streaming each over
// request.Responses. The final Response carries token counts and timings.
//
// Fixes over the previous version: the prefill loop used `defer mlx.Free(temp)`,
// which accumulated one deferred free per chunk and pinned every chunk's
// logits until the whole request finished; chunk outputs are now freed inside
// the loop. The pending nextSample/nextLogprobs pair is also freed on the EOS
// break instead of leaking.
func (r *Runner) TextGenerationPipeline(request Request) error {
	model, ok := r.Model.(base.TextGeneration)
	if !ok {
		return errors.New("model does not support causal language modeling")
	}
	inputs, err := r.Tokenizer.Encode(request.Prompt, true)
	if err != nil {
		return err
	}
	// Reuse the longest cached prefix; tokens holds the unprocessed suffix.
	caches, tokens := r.FindNearestCache(inputs)
	if len(caches) == 0 {
		if cacher, ok := model.(base.Cacher); ok {
			caches = cacher.Cache()
		} else {
			caches = make([]cache.Cache, model.NumLayers())
			for i := range caches {
				caches[i] = cache.NewKVCache()
			}
		}
	}
	total, processed := len(tokens), 0
	slog.Info("Prompt processing progress", "processed", processed, "total", total)
	// Prefill in chunks of up to 2048 tokens, leaving the final token for the
	// first sampling step.
	for total-processed > 1 {
		n := min(2<<10, total-processed-1)
		temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		mlx.Eval(func() []*mlx.Array {
			s := make([]*mlx.Array, 2*len(caches))
			for i, c := range caches {
				s[2*i], s[2*i+1] = c.State()
			}
			return s
		}()...)
		// Free this chunk's logits now that the cache state is materialized;
		// a defer here would pin every chunk's output until function return.
		mlx.Free(temp)
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)
		mlx.ClearCache()
	}
	// step runs one forward pass and returns the sampled next token plus the
	// log-probability distribution at the last position.
	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
		// TODO: additional logit processing (logit bias, repetition penalty, etc.)
		logprobs := logits.Subtract(logits.Logsumexp(true))
		return request.Sample(logprobs), logprobs
	}
	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
	mlx.AsyncEval(sample, logprobs)
	// buffer partial, multibyte unicode
	var b bytes.Buffer
	now := time.Now()
	final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
	outputs := make([]int32, 0, request.Options.MaxTokens)
	for i := range request.Options.MaxTokens {
		// Queue the next step before consuming the current sample so the
		// accelerator stays busy while we decode and stream.
		nextSample, nextLogprobs := step(sample)
		mlx.AsyncEval(nextSample, nextLogprobs)
		if i == 0 {
			slog.Info("Prompt processing progress", "processed", total, "total", total)
			mlx.Eval(sample)
			final.PromptTokensDuration = time.Since(now)
			now = time.Now()
		}
		output := int32(sample.Int())
		outputs = append(outputs, output)
		if r.Tokenizer.Is(output, tokenizer.SpecialEOS) {
			final.Token = int(output)
			final.DoneReason = 0
			final.CompletionTokens = i
			// The speculatively queued pair is never used past this point.
			mlx.Free(nextSample, nextLogprobs)
			break
		}
		request.Responses <- Response{
			Text:  r.Decode(output, &b),
			Token: int(output),
		}
		mlx.Free(sample, logprobs)
		if i%256 == 0 {
			mlx.ClearCache()
		}
		sample, logprobs = nextSample, nextLogprobs
	}
	mlx.Free(sample, logprobs)
	final.CompletionTokensDuration = time.Since(now)
	request.Responses <- final
	// Remember the full sequence so later requests can reuse this KV cache.
	r.InsertCache(append(inputs, outputs...), caches)
	return nil
}
// Decode converts a single token id to text, buffering partial multi-byte
// UTF-8 sequences in b across calls. It returns "" while a sequence is still
// incomplete and flushes the buffer once it is valid (or clearly invalid).
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
	token, err := r.Tokenizer.Decode([]int32{sample})
	if err != nil {
		slog.Error("Failed to decode tokens", "error", err)
		return ""
	}
	if _, err := b.WriteString(token); err != nil {
		slog.Error("Failed to write token to buffer", "error", err)
		return ""
	}
	if text := b.String(); utf8.ValidString(text) {
		b.Reset()
		return text
	} else if b.Len() >= utf8.UTFMax {
		// A UTF-8 sequence is at most utf8.UTFMax bytes; anything this long
		// that is still invalid will never become valid, so flush it as-is
		// rather than stalling the stream.
		b.Reset()
		return text
	}
	return ""
}

View File

@@ -1,110 +0,0 @@
package mlxrunner
import (
"context"
"log/slog"
"net"
"net/http"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
_ "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
// Request is an in-flight generation request: the parsed HTTP payload, the
// channel results are streamed on, the pipeline that will service it, and the
// sampler built from its options.
type Request struct {
	TextCompletionsRequest
	Responses chan Response
	Pipeline  func(Request) error
	sample.Sampler
	// caches reserved for per-request KV cache state.
	caches []cache.Cache
}
// TextCompletionsRequest is the JSON body accepted by POST /v1/completions.
type TextCompletionsRequest struct {
	Prompt  string `json:"prompt"`
	Options struct {
		Temperature float32 `json:"temperature"`
		TopP        float32 `json:"top_p"`
		MinP        float32 `json:"min_p"`
		TopK        int     `json:"top_k"`
		MaxTokens   int     `json:"max_tokens"`
		// Deprecated: use MaxTokens instead
		NumPredict int `json:"num_predict"`
	} `json:"options"`
}
// Response is one streamed chunk of a completion. Intermediate responses
// carry Text/Token; the final response carries counts and timings.
type Response struct {
	Text                     string        `json:"content,omitempty"`
	Token                    int           `json:"token,omitempty"`
	Logprobs                 []float32     `json:"logprobs,omitempty"`
	Done                     bool          `json:"done,omitempty"`
	DoneReason               int           `json:"done_reason,omitempty"`
	PromptTokens             int           `json:"prompt_eval_count,omitempty"`
	PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
	CompletionTokens         int           `json:"eval_count,omitempty"`
	CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
	TotalTokens              int           `json:"total_tokens,omitempty"`
}
// Runner owns the loaded model and tokenizer, the serialized request queue,
// and the prefix KV-cache entries shared across requests.
type Runner struct {
	Model        base.Model
	Tokenizer    tokenizer.Tokenizer
	Requests     chan Request
	CacheEntries map[int32]*CacheEntry
}
// Load opens the named model, constructs the architecture and tokenizer, and
// loads all weights (applying any post-load fusions).
func (r *Runner) Load(name model.Name) (err error) {
	root, err := model.Open(name)
	if err != nil {
		return err
	}
	defer root.Close()
	if r.Model, err = base.New(root); err != nil {
		return err
	}
	if r.Tokenizer, err = tokenizer.New(root); err != nil {
		return err
	}
	weights, quantizations, afterLoadFuncs := base.Walk(r.Model)
	return mlx.LoadAll(root, weights, quantizations, afterLoadFuncs)
}
// Run starts the request-processing goroutine and the HTTP server, returning
// when either fails.
//
// Fix: the previous version used `break` inside the select when the pipeline
// returned an error, which skipped close(request.Responses) — the HTTP
// handler ranging over that channel then blocked forever. The channel is now
// always closed and the error is logged.
func (r *Runner) Run(host, port string, mux http.Handler) error {
	g, ctx := errgroup.WithContext(context.Background())
	g.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case request := <-r.Requests:
				if err := request.Pipeline(request); err != nil {
					slog.Error("pipeline error", "error", err)
				}
				// Always close so streaming handlers terminate their range loop.
				close(request.Responses)
			}
		}
	})
	g.Go(func() error {
		slog.Info("Starting HTTP server", "host", host, "port", port)
		return http.ListenAndServe(net.JoinHostPort(host, port), mux)
	})
	return g.Wait()
}

View File

@@ -1,75 +0,0 @@
package sample
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// Sampler transforms a log-probability array into the next sampled token
// (or into filtered logprobs, when used as a link in a chain).
type Sampler interface {
	Sample(*mlx.Array) *mlx.Array
}
// New builds a Sampler from the request options. A temperature of zero
// selects greedy decoding; otherwise the active top-p, min-p, and top-k
// filters are chained ahead of temperature-scaled categorical sampling.
//
// Parameter names follow Go naming conventions (topP, not top_p); the call
// interface is unchanged.
func New(temp, topP, minP float32, topK int) Sampler {
	if temp == 0 {
		return greedy{}
	}
	var samplers []Sampler
	if topP > 0 && topP < 1 {
		samplers = append(samplers, TopP(topP))
	}
	if minP != 0 {
		samplers = append(samplers, MinP(minP))
	}
	if topK > 0 {
		samplers = append(samplers, TopK(topK))
	}
	samplers = append(samplers, Temperature(temp))
	return chain(samplers)
}
// greedy always picks the highest-probability token (used when temperature is 0).
type greedy struct{}

// Sample returns the argmax over the final axis.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return logits.Argmax(-1, false)
}
// chain applies a sequence of samplers in order, feeding each one's output
// into the next.
type chain []Sampler

// Sample runs every sampler in the chain over the logits.
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
	out := logits
	for i := range c {
		out = c[i].Sample(out)
	}
	return out
}
// Temperature scales logits by 1/t before drawing a categorical sample.
type Temperature float32

// Sample divides the logits by the temperature and samples one token.
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
	return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}
// TopP restricts sampling to the smallest set of tokens whose cumulative
// probability exceeds p (nucleus sampling). Currently a no-op stub.
type TopP float32

func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}
// MinP drops tokens whose probability is below p times the top token's
// probability. Currently a no-op stub.
type MinP float32

func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}
// TopK restricts sampling to the k highest-probability tokens by masking
// everything else to -inf.
type TopK int

// Sample suppresses all but the top-k logprobs along the final axis.
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
	// Argpartition on the negated values puts the k largest first; the tail
	// indices are the ones to mask out.
	tail := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
	neginf := mlx.FromValue(float32(math.Inf(-1)))
	return logprobs.PutAlongAxis(tail, neginf, -1)
}

View File

@@ -1,180 +0,0 @@
package mlxrunner
import (
"bytes"
"cmp"
"encoding/json"
"flag"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
// Execute is the entry point for the standalone MLX runner process: it parses
// flags, loads the model, wires up the HTTP API, and serves until the server
// exits.
func Execute(args []string) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	var (
		name model.Name
		port int
	)
	flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
	flagSet.Var(&name, "model", "Model name")
	flagSet.IntVar(&port, "port", 0, "Port to listen on")
	_ = flagSet.Bool("verbose", false, "Enable debug logging")
	// ExitOnError: Parse exits the process on a bad flag, so the error is unchecked.
	flagSet.Parse(args)
	runner := Runner{
		Requests:     make(chan Request),
		CacheEntries: make(map[int32]*CacheEntry),
	}
	if err := runner.Load(name); err != nil {
		return err
	}
	mux := http.NewServeMux()
	// Readiness probe: the model is fully loaded by the time we serve.
	mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewEncoder(w).Encode(map[string]any{
			"status":   0,
			"progress": 100,
		}); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})
	mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case "POST":
			fallthrough
		case "GET":
			// The single model is loaded at startup, so load/list always succeed.
			if err := json.NewEncoder(w).Encode(map[string]any{
				"Success": true,
			}); err != nil {
				slog.Error("Failed to encode response", "error", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
		case "DELETE":
			// TODO: cleanup model and cache
		}
	})
	// Streaming completions: responses are written as JSON lines as they arrive.
	mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
		request := Request{Responses: make(chan Response)}
		if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
			slog.Error("Failed to decode request", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}
		// Honor the deprecated num_predict alias, then apply the default cap.
		request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
		if request.Options.MaxTokens < 1 {
			request.Options.MaxTokens = 16 << 10
		}
		request.Pipeline = runner.TextGenerationPipeline
		request.Sampler = sample.New(
			request.Options.Temperature,
			request.Options.TopP,
			request.Options.MinP,
			request.Options.TopK,
		)
		// Hand off to the single worker goroutine, then stream until it
		// closes the channel.
		runner.Requests <- request
		w.Header().Set("Content-Type", "application/jsonl")
		w.WriteHeader(http.StatusOK)
		enc := json.NewEncoder(w)
		for response := range request.Responses {
			if err := enc.Encode(response); err != nil {
				slog.Error("Failed to encode response", "error", err)
				return
			}
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	})
	mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
		var b bytes.Buffer
		if _, err := io.Copy(&b, r.Body); err != nil {
			slog.Error("Failed to read request body", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}
		tokens, err := runner.Tokenizer.Encode(b.String(), true)
		if err != nil {
			slog.Error("Failed to tokenize text", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if err := json.NewEncoder(w).Encode(tokens); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})
	// Legacy route aliases for older clients.
	for source, target := range map[string]string{
		"GET /health":      "/v1/status",
		"POST /load":       "/v1/models",
		"POST /completion": "/v1/completions",
	} {
		mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
	}
	// Wrap the mux with request logging; redirects (3xx) are not logged.
	return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
		t := time.Now()
		mux.ServeHTTP(recorder, r)
		// Level defaults to Info (zero value) for 2xx responses.
		var level slog.Level
		switch {
		case recorder.code >= 500:
			level = slog.LevelError
		case recorder.code >= 400:
			level = slog.LevelWarn
		case recorder.code >= 300:
			return
		}
		slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
	}))
}
// statusRecorder wraps an http.ResponseWriter to capture the status code for
// request logging.
type statusRecorder struct {
	http.ResponseWriter
	code int
}
// WriteHeader records the status code before delegating to the wrapped writer.
func (w *statusRecorder) WriteHeader(code int) {
	w.code = code
	w.ResponseWriter.WriteHeader(code)
}
// Status renders the recorded code as "200 OK"-style text for log output.
func (w *statusRecorder) Status() string {
	return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}
// Flush forwards to the wrapped writer's Flusher, if any, so streamed
// responses are delivered promptly.
func (w *statusRecorder) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}