Compare commits

...

12 Commits

Author SHA1 Message Date
Patrick Devine
f2279f9a9d feature: add ctrl-g to allow users to use an editor to edit their prompt 2026-02-10 16:42:10 -08:00
Patrick Devine
44bdd9a2ef Add MLX runner with GLM4-MoE-Lite model support (#14185)
This change adds a new MLX based runner which includes:

  * Method-based MLX bindings
  * Subprocess-based MLX runner (x/mlxrunner)
  * KV cache with tree management
  * A basic sampler

The GLM4-MoE-Lite model has been ported to use the new bindings.

---------

Co-authored-by: Michael Yang <git@mxy.ng>
2026-02-10 14:57:57 -08:00
Michael
db493d6e5e docs: update broken links on FAQ and quick cleanup (#14194)
docs: update broken links on FAQ and quick cleanup
2026-02-10 16:52:20 -05:00
Bruce MacDonald
75695f16a5 docs: integration overview (#13831)
Group integrations into high-level types
2026-02-10 11:41:09 -08:00
Patrick Devine
a0407d07fa safetensors quantization for mlx (#14184)
This change includes:
  - changes to the safetensors metadata format
  - changes to the create command to properly create the blobs with the new format
  - changes to load the new format
  - fixes ollama show to properly show each tensor
2026-02-10 11:29:17 -08:00
Jeffrey Morgan
9ec733e527 cmd: make 'ollama login' and 'ollama logout' aliases for 'ollama signin' and 'ollama signout' respectively (#14144) 2026-02-09 19:12:42 -08:00
Parth Sareen
5ef04dab52 cmd: ollama launch pi (#14084) 2026-02-09 19:07:41 -08:00
Daniel Hiltgen
aea316f1e9 win: add curl-style install script (#14178)
This adds a new powershell install script suitable for running via

  irm https://ollama.com/install.ps1 | iex

If you download the script and run '-?' it reports basic usage
information, as well as usage examples for common customization
options.  The script is signed as part of the release process
to ensure it can run on a typically configured Windows system.

This does not include doc updates - we can merge those after a release
ships to avoid user confusion.
2026-02-09 15:28:11 -08:00
Patrick Devine
235ba3df5c cmd: ollama menu and launch improvements (#14038) 2026-02-09 11:30:16 -08:00
Jeffrey Morgan
099a0f18ef build: fix Dockerfile mlx directory (#14131) 2026-02-06 17:08:53 -08:00
Richard Lyons
fff696ee31 docs: increased RAM requirement for parallelism 2026-02-06 15:49:39 -08:00
Jeffrey Morgan
2e3ce6eab3 anthropic: do not count image tokens for now (#14127) 2026-02-06 15:33:18 -08:00
78 changed files with 20213 additions and 529 deletions

View File

@@ -337,6 +337,7 @@ jobs:
name: bundles-windows
path: |
dist/*.zip
dist/*.ps1
dist/OllamaSetup.exe
linux-build:
@@ -514,6 +515,9 @@ jobs:
- name: Log dist contents
run: |
ls -l dist/
- name: Copy install scripts to dist
run: |
cp scripts/install.sh dist/install.sh
- name: Generate checksum file
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
@@ -536,7 +540,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -147,7 +147,7 @@ ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local

View File

@@ -897,11 +897,5 @@ func countContentBlock(block any) int {
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

13
cmd/background_unix.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
// Setpgid prevents the server from being killed when the parent process exits.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}

12
cmd/background_windows.go Normal file
View File

@@ -0,0 +1,12 @@
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
CreationFlags: 0x08000000,
HideWindow: true,
}
}

View File

@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -37,6 +38,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -1804,6 +1806,190 @@ Environment Variables:
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
}
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
func ensureServerRunning(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Check if server is already running
if err := client.Heartbeat(ctx); err == nil {
return nil // server is already running
}
// Server not running, start it in the background
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("could not find executable: %w", err)
}
serverCmd := exec.CommandContext(ctx, exe, "serve")
serverCmd.Env = os.Environ()
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
if err := serverCmd.Start(); err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
// Wait for the server to be ready
for {
time.Sleep(500 * time.Millisecond)
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
// runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) {
// Ensure the server is running before showing the TUI
if err := ensureServerRunning(cmd.Context()); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
return
}
// errSelectionCancelled is returned when user cancels model selection
errSelectionCancelled := errors.New("cancelled")
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
}
result, err := tui.SelectSingle(title, tuiItems)
if errors.Is(err, tui.ErrCancelled) {
return "", errSelectionCancelled
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
}
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, errSelectionCancelled
}
return result, err
}
for {
result, err := tui.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
runModel := func(modelName string) {
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
launchIntegration := func(name string) bool {
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, errSelectionCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
switch result.Selection {
case tui.SelectionNone:
// User quit
return
case tui.SelectionRunModel:
_ = config.SetLastSelection("run")
// Run last model directly if configured and still exists
if modelName := config.LastModel(); modelName != "" && config.ModelExists(cmd.Context(), modelName) {
runModel(modelName)
} else {
// No last model or model no longer exists, show picker
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
runModel(modelName)
}
case tui.SelectionChangeRunModel:
_ = config.SetLastSelection("run")
// Use model from modal if selected, otherwise show picker
modelName := result.Model
if modelName == "" {
var err error
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
// Use model from modal if selected, otherwise show picker
if result.Model != "" {
// Model already selected from modal - save and launch
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
}
}
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
@@ -1826,11 +2012,13 @@ func NewCLI() *cobra.Command {
return
}
cmd.Print(cmd.UsageString())
runInteractiveTUI(cmd)
},
}
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -1934,6 +2122,15 @@ func NewCLI() *cobra.Command {
RunE: SigninHandler,
}
loginCmd := &cobra.Command{
Use: "login",
Short: "Sign in to ollama.com",
Hidden: true,
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SigninHandler,
}
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
@@ -1942,6 +2139,15 @@ func NewCLI() *cobra.Command {
RunE: SignoutHandler,
}
logoutCmd := &cobra.Command{
Use: "logout",
Short: "Sign out from ollama.com",
Hidden: true,
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SignoutHandler,
}
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
@@ -2004,7 +2210,7 @@ func NewCLI() *cobra.Command {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_EDITOR"], envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{
envVars["OLLAMA_DEBUG"],
@@ -2038,13 +2244,15 @@ func NewCLI() *cobra.Command {
pullCmd,
pushCmd,
signinCmd,
loginCmd,
signoutCmd,
logoutCmd,
listCmd,
psCmd,
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat),
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd

View File

@@ -3,12 +3,15 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
)
type integration struct {
@@ -17,7 +20,9 @@ type integration struct {
}
type config struct {
Integrations map[string]*integration `json:"integrations"`
Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"`
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
}
func configPath() (string, error) {
@@ -146,6 +151,74 @@ func saveIntegration(appName string, models []string) error {
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
return ""
}
return ic.Models[0]
}
// LastModel returns the last model that was run, or empty string if none.
func LastModel() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastModel
}
// SetLastModel saves the last model that was run.
func SetLastModel(model string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastModel = model
return save(cfg)
}
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
func LastSelection() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastSelection
}
// SetLastSelection saves the last menu selection ("run" or integration name).
func SetLastSelection(selection string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastSelection = selection
return save(cfg)
}
// ModelExists checks if a model exists on the Ollama server.
func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -4,9 +4,12 @@ import (
"context"
"errors"
"fmt"
"io"
"maps"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
@@ -57,11 +60,12 @@ var integrations = map[string]Runner{
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
var recommendedModels = []ModelItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
@@ -74,6 +78,256 @@ var integrationAliases = map[string]bool{
"moltbot": true,
}
// integrationInstallURLs maps integration names to their install script URLs.
var integrationInstallURLs = map[string]string{
"claude": "https://claude.ai/install.sh",
"openclaw": "https://openclaw.ai/install.sh",
"droid": "https://app.factory.ai/cli",
"opencode": "https://opencode.ai/install",
}
// CanInstallIntegration returns true if we have an install script for this integration.
func CanInstallIntegration(name string) bool {
_, ok := integrationInstallURLs[name]
return ok
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
switch name {
case "claude":
c := &Claude{}
_, err := c.findPath()
return err == nil
case "openclaw":
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
case "codex":
_, err := exec.LookPath("codex")
return err == nil
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
default:
return true // Assume installed for unknown integrations
}
}
// InstallIntegration downloads and runs the install script for an integration.
func InstallIntegration(name string) error {
url, ok := integrationInstallURLs[name]
if !ok {
return fmt.Errorf("no install script available for %s", name)
}
// Download the install script
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download install script: HTTP %d", resp.StatusCode)
}
script, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read install script: %w", err)
}
// Create a temporary file for the script
tmpDir := os.TempDir()
scriptPath := filepath.Join(tmpDir, fmt.Sprintf("install-%s.sh", name))
if err := os.WriteFile(scriptPath, script, 0o700); err != nil {
return fmt.Errorf("failed to write install script: %w", err)
}
defer os.Remove(scriptPath)
// Execute the script with bash
cmd := exec.Command("bash", scriptPath)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("install script failed: %w", err)
}
return nil
}
// SelectModel lets the user select a model to run.
// ModelItem represents a model for selection.
type ModelItem struct {
Name string
Description string
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
// SelectModelWithSelector prompts the user to select a model using the provided selector.
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return "", err
}
models, err := client.List(ctx)
if err != nil {
return "", err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
preChecked = []string{lastModel}
}
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
if len(items) == 0 {
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
// Sort with last model first, then existing models, then recommendations
slices.SortStableFunc(items, func(a, b ModelItem) int {
aIsLast := a.Name == lastModel
bIsLast := b.Name == lastModel
if aIsLast != bIsLast {
if aIsLast {
return -1
}
return 1
}
aExists := existingModels[a.Name]
bExists := existingModels[b.Name]
if aExists != bExists {
if aExists {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
selected, err := selector("Select model to run:", items)
if err != nil {
return "", err
}
// If the selected model isn't installed, pull it first
if !existingModels[selected] {
msg := fmt.Sprintf("Download %s?", selected)
if ok, err := confirmPrompt(msg); err != nil {
return "", err
} else if !ok {
return "", errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, selected); err != nil {
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
}
}
// If it's a cloud model, ensure user is signed in
if cloudModels[selected] {
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return "", err
}
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
if err != nil || !yes {
return "", fmt.Errorf("%s requires sign in", selected)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return "", ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func SelectModel(ctx context.Context) (string, error) {
return SelectModelWithSelector(ctx, defaultSingleSelector)
}
func defaultSingleSelector(title string, items []ModelItem) (string, error) {
selectItems := make([]selectItem, len(items))
for i, item := range items {
selectItems[i] = selectItem(item)
}
return selectPrompt(title, selectItems)
}
func defaultMultiSelector(title string, items []ModelItem, preChecked []string) ([]string, error) {
selectItems := make([]selectItem, len(items))
for i, item := range items {
selectItems[i] = selectItem(item)
}
return multiSelectPrompt(title, selectItems, preChecked)
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
@@ -96,8 +350,8 @@ func selectIntegration() (string, error) {
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
@@ -133,7 +387,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
var selected []string
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
@@ -142,7 +396,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
model, err := single(prompt, items)
if err != nil {
return nil, err
}
@@ -227,12 +481,17 @@ func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
modelItems, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
if len(modelItems) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
items := make([]selectItem, len(modelItems))
for i, mi := range modelItems {
items[i] = selectItem(mi)
}
return items, existingModels, cloudModels, client, nil
}
@@ -303,6 +562,11 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
}
}
// selectModels lets the user select models for an integration using default selectors.
func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selectModelsWithSelectors(ctx, name, current, defaultSingleSelector, defaultMultiSelector)
}
func runIntegration(name, modelName string, args []string) error {
r, ok := integrations[name]
if !ok {
@@ -335,15 +599,110 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
return saveAliases(name, aliases)
}
// LaunchIntegration launches the named integration using saved config or prompts for setup.
func LaunchIntegration(name string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// Try to use saved config
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], nil)
}
// No saved config - prompt user to run setup
return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
}
// LaunchIntegrationWithModel launches the named integration with the specified model.
func LaunchIntegrationWithModel(name, modelName string) error {
return runIntegration(name, modelName, nil)
}
// SaveIntegrationModel saves the model for an integration.
func SaveIntegrationModel(name, modelName string) error {
// Load existing models and prepend the new one
var models []string
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
models = existing.Models
// Remove the model if it already exists
for i, m := range models {
if m == modelName {
models = append(models[:i], models[i+1:]...)
break
}
}
}
// Prepend the new model
models = append([]string{modelName}, models...)
return saveIntegration(name, models)
}
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
} else {
fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
}
return nil
}
// ConfigureIntegration allows the user to select/change the model for an integration.
func ConfigureIntegration(ctx context.Context, name string) error {
return ConfigureIntegrationWithSelectors(ctx, name, defaultSingleSelector, defaultMultiSelector)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// The runTUI callback is called when no arguments are provided (alias for main TUI).
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
@@ -362,6 +721,12 @@ Examples:
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// No args - run the main TUI (same as 'ollama')
if len(args) == 0 && modelFlag == "" && !configFlag {
runTUI(cmd)
return nil
}
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
@@ -582,7 +947,7 @@ type modelInfo struct {
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
@@ -602,7 +967,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
item := ModelItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
@@ -651,7 +1016,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
@@ -686,6 +1051,56 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
return resp.RemoteModel != ""
}
// GetModelItems returns a list of model items including recommendations for the TUI.
// It includes all locally available models plus recommended models that aren't installed.
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil
}
models, err := client.List(ctx)
if err != nil {
return nil, nil
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
preChecked = []string{lastModel}
}
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
// Sort with last model first, then existing models, then recommendations
slices.SortStableFunc(items, func(a, b ModelItem) int {
aIsLast := a.Name == lastModel
bIsLast := b.Name == lastModel
if aIsLast != bIsLast {
if aIsLast {
return -1
}
return 1
}
aExists := existingModels[a.Name]
bExists := existingModels[b.Name]
if aExists != bExists {
if aExists {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
return items, existingModels
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()

View File

@@ -94,8 +94,10 @@ func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
// Mock TUI function (not called in these tests)
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck)
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
@@ -128,6 +130,75 @@ func TestLaunchCmd(t *testing.T) {
})
}
func TestLaunchCmd_TUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
// Will error because claude isn't configured, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --model flag provided")
}
})
t.Run("--config flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --config flag provided")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
@@ -168,7 +239,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
@@ -314,7 +385,7 @@ func TestIsCloudModel(t *testing.T) {
})
}
func names(items []selectItem) []string {
func names(items []ModelItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)

237
cmd/config/pi.go Normal file
View File

@@ -0,0 +1,237 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
client := api.NewClient(envconfig.Host(), http.DefaultClient)
ctx := context.Background()
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, createConfig(ctx, client, model))
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}
// createConfig builds Pi model config with capability detection
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
cfg := map[string]any{
"id": modelID,
"_launch": true,
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
if err != nil {
return cfg
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
cfg["input"] = []string{"text", "image"}
} else {
cfg["input"] = []string{"text"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
cfg["reasoning"] = true
}
// Extract context window from ModelInfo
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
cfg["contextWindow"] = int(ctxLen)
}
break
}
}
return cfg
}

830
cmd/config/pi_test.go Normal file
View File

@@ -0,0 +1,830 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
// Mock Ollama server for createConfig calls during Edit
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}
func TestIsPiOllamaModel(t *testing.T) {
tests := []struct {
name string
cfg map[string]any
want bool
}{
{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
{"without _launch", map[string]any{"id": "m"}, false},
{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
{"empty map", map[string]any{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isPiOllamaModel(tt.cfg); got != tt.want {
t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
}
})
}
}
func TestCreateConfig(t *testing.T) {
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llava:7b")
if cfg["id"] != "llava:7b" {
t.Errorf("id = %v, want llava:7b", cfg["id"])
}
if cfg["_launch"] != true {
t.Error("expected _launch = true")
}
input, ok := cfg["input"].([]string)
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
})
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llama3.2")
input, ok := cfg["input"].([]string)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", cfg["input"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set for non-thinking model")
}
})
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "qwq")
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for thinking model")
}
})
t.Run("extracts context window from model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "llama3.2")
if cfg["contextWindow"] != 131072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("handles all capabilities together", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "qwen3-vision")
input := cfg["input"].([]string)
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
t.Errorf("input = %v, want [text image]", input)
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true")
}
if cfg["contextWindow"] != 32768 {
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
}
})
t.Run("returns minimal config when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "missing-model")
if cfg["id"] != "missing-model" {
t.Errorf("id = %v, want missing-model", cfg["id"])
}
if cfg["_launch"] != true {
t.Error("expected _launch = true")
}
// Should not have capability fields
if _, ok := cfg["input"]; ok {
t.Error("input should not be set when show fails")
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set when show fails")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set when show fails")
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "test-model")
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set for zero value")
}
})
}
// Ensure Capability constants used in createConfig match expected values
func TestPiCapabilityConstants(t *testing.T) {
if model.CapabilityVision != "vision" {
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
}
if model.CapabilityThinking != "thinking" {
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
}
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"slices"
@@ -79,6 +80,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt")
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
fmt.Fprintln(os.Stderr, "")
@@ -147,6 +149,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
scanner.Prompt.UseAlt = false
sb.Reset()
continue
case errors.Is(err, readline.ErrEditPrompt):
sb.Reset()
content, err := editInExternalEditor(line)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
continue
}
if strings.TrimSpace(content) == "" {
continue
}
scanner.Prefill = content
continue
case err != nil:
return err
@@ -598,6 +612,57 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return strings.TrimSpace(input), imgs, nil
}
func editInExternalEditor(content string) (string, error) {
editor := envconfig.Editor()
if editor == "" {
editor = os.Getenv("VISUAL")
}
if editor == "" {
editor = os.Getenv("EDITOR")
}
if editor == "" {
editor = "vi"
}
// Check that the editor binary exists
name := strings.Fields(editor)[0]
if _, err := exec.LookPath(name); err != nil {
return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name)
}
tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt")
if err != nil {
return "", fmt.Errorf("creating temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if content != "" {
if _, err := tmpFile.WriteString(content); err != nil {
tmpFile.Close()
return "", fmt.Errorf("writing to temp file: %w", err)
}
}
tmpFile.Close()
args := strings.Fields(editor)
args = append(args, tmpFile.Name())
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("editor exited with error: %w", err)
}
data, err := os.ReadFile(tmpFile.Name())
if err != nil {
return "", fmt.Errorf("reading temp file: %w", err)
}
return strings.TrimRight(string(data), "\n"), nil
}
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {

507
cmd/tui/selector.go Normal file
View File

@@ -0,0 +1,507 @@
package tui
import (
"errors"
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
var (
selectorTitleStyle = lipgloss.NewStyle().
Bold(true)
selectorItemStyle = lipgloss.NewStyle().
PaddingLeft(4)
selectorSelectedItemStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true)
selectorDescStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorFilterStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
selectorInputStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
selectorCheckboxStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
Bold(true)
selectorDefaultTagStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
selectorHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorMoreStyle = lipgloss.NewStyle().
PaddingLeft(4).
Foreground(lipgloss.Color("241")).
Italic(true)
)
const maxSelectorItems = 10
// ErrCancelled is returned when the user cancels the selection.
var ErrCancelled = errors.New("cancelled")
// SelectItem represents an item that can be selected.
type SelectItem struct {
Name string
Description string
}
// selectorModel is the bubbletea model for single selection.
type selectorModel struct {
title string
items []SelectItem
filter string
cursor int
scrollOffset int
selected string
cancelled bool
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m selectorModel) Init() tea.Cmd {
return nil
}
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
if len(filtered) > 0 && m.cursor < len(filtered) {
m.selected = filtered[m.cursor].Name
}
return m, tea.Quit
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.scrollOffset -= maxSelectorItems
if m.scrollOffset < 0 {
m.scrollOffset = 0
}
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
return m, nil
}
func (m selectorModel) View() string {
// Clear screen when exiting
if m.cancelled || m.selected != "" {
return ""
}
var s strings.Builder
// Title with filter
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
if item.Description != "" {
s.WriteString(" ")
s.WriteString(selectorDescStyle.Render("- " + item.Description))
}
s.WriteString("\n")
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
s.WriteString("\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
return s.String()
}
// SelectSingle prompts the user to select a single item from a list.
func SelectSingle(title string, items []SelectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(selectorModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.selected, nil
}
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
title string
items []SelectItem
itemIndex map[string]int
filter string
cursor int
scrollOffset int
checked map[int]bool
checkOrder []int
cancelled bool
confirmed bool
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
m := multiSelectorModel{
title: title,
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
return m
}
func (m multiSelectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m *multiSelectorModel) toggleItem() {
filtered := m.filteredItems()
if len(filtered) == 0 || m.cursor >= len(filtered) {
return
}
item := filtered[m.cursor]
origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] {
delete(m.checked, origIdx)
for i, idx := range m.checkOrder {
if idx == origIdx {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
} else {
m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx)
}
}
func (m multiSelectorModel) selectedCount() int {
return len(m.checkOrder)
}
func (m multiSelectorModel) Init() tea.Cmd {
return nil
}
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
// Enter confirms if at least one item is selected
if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
// Space always toggles selection
m.toggleItem()
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.scrollOffset -= maxSelectorItems
if m.scrollOffset < 0 {
m.scrollOffset = 0
}
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
return m, nil
}
func (m multiSelectorModel) View() string {
// Clear screen when exiting
if m.cancelled || m.confirmed {
return ""
}
var s strings.Builder
// Title with filter
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := m.itemIndex[item.Name]
// Checkbox
var checkbox string
if m.checked[origIdx] {
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
} else {
checkbox = selectorCheckboxStyle.Render("[ ]")
}
// Cursor and name
var line string
if idx == m.cursor {
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
} else {
line = " " + checkbox + " " + item.Name
}
// Default tag
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
line += " " + selectorDefaultTagStyle.Render("(default)")
}
s.WriteString(line)
s.WriteString("\n")
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
s.WriteString("\n")
// Status line
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
return s.String()
}
// SelectMultiple prompts the user to select multiple items from a list.
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
m := newMultiSelectorModel(title, items, preChecked)
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return nil, fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
}
var result []string
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
}
return result, nil
}

736
cmd/tui/tui.go Normal file
View File

@@ -0,0 +1,736 @@
package tui
import (
"context"
"errors"
"fmt"
"os/exec"
"runtime"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/version"
)
var (
titleStyle = lipgloss.NewStyle().
Bold(true).
MarginBottom(1)
versionStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("245"))
itemStyle = lipgloss.NewStyle().
PaddingLeft(2)
selectedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true)
greyedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("241"))
greyedSelectedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("243"))
descStyle = lipgloss.NewStyle().
PaddingLeft(4).
Foreground(lipgloss.Color("241"))
modelStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("245"))
notInstalledStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
)
type menuItem struct {
title string
description string
integration string // integration name for loading model config, empty if not an integration
isRunModel bool // true for the "Run a model" option
isOthers bool // true for the "Others..." toggle item
}
var mainMenuItems = []menuItem{
{
title: "Run a model",
description: "Start an interactive chat with a local model",
isRunModel: true,
},
{
title: "Launch Claude Code",
description: "Open Claude Code AI assistant",
integration: "claude",
},
{
title: "Launch Codex",
description: "Open Codex CLI",
integration: "codex",
},
{
title: "Launch Open Claw",
description: "Open the Open Claw integration",
integration: "openclaw",
},
}
var othersMenuItem = menuItem{
title: "Others...",
description: "Show additional integrations",
isOthers: true,
}
// getOtherIntegrations returns the list of other integrations, filtering out
// Codex if it's not installed (since it requires npm install).
func getOtherIntegrations() []menuItem {
return []menuItem{
{
title: "Launch Droid",
description: "Open Droid integration",
integration: "droid",
},
{
title: "Launch Open Code",
description: "Open Open Code integration",
integration: "opencode",
},
{
title: "Launch Pi",
description: "Open Pi coding agent",
integration: "pi",
},
}
}
type model struct {
items []menuItem
cursor int
quitting bool
selected bool // true if user made a selection (enter/space)
changeModel bool // true if user pressed right arrow to change model
showOthers bool // true if "Others..." is expanded
availableModels map[string]bool // cache of available model names
err error
// Modal state
showingModal bool // true when model picker modal is visible
modalSelector selectorModel // the selector model for the modal
modalItems []SelectItem // cached items for the modal
// Sign-in dialog state
showingSignIn bool // true when sign-in dialog is visible
signInURL string // URL for sign-in
signInModel string // model that requires sign-in
signInSpinner int // spinner frame index
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
}
// signInTickMsg is sent to animate the sign-in spinner
type signInTickMsg struct{}
// signInCheckMsg is sent to check if sign-in is complete
type signInCheckMsg struct {
signedIn bool
userName string
}
// modelExists checks if a model exists in the cached available models.
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
return false
}
if m.availableModels[name] {
return true
}
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
for modelName := range m.availableModels {
if strings.HasPrefix(modelName, name+":") {
return true
}
}
return false
}
// buildModalItems creates the list of models for the modal selector.
func (m *model) buildModalItems() []SelectItem {
modelItems, _ := config.GetModelItems(context.Background())
var items []SelectItem
for _, item := range modelItems {
items = append(items, SelectItem{Name: item.Name, Description: item.Description})
}
return items
}
// openModelModal opens the model picker modal.
func (m *model) openModelModal() {
m.modalItems = m.buildModalItems()
m.modalSelector = selectorModel{
title: "Select model:",
items: m.modalItems,
}
m.showingModal = true
}
// isCloudModel returns true if the model name indicates a cloud model.
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
// checkCloudSignIn checks if a cloud model needs sign-in.
// Returns a command to start sign-in if needed, or nil if already signed in.
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
if modelName == "" || !isCloudModel(modelName) {
return nil
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return nil // Already signed in
}
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
}
return nil
}
// startSignIn initiates the sign-in flow for a cloud model.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
m.showingModal = false
m.showingSignIn = true
m.signInURL = signInURL
m.signInModel = modelName
m.signInSpinner = 0
m.signInFromModal = fromModal
// Open browser (best effort)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", signInURL).Start()
case "linux":
_ = exec.Command("xdg-open", signInURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", signInURL).Start()
}
// Start the spinner tick
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
// checkSignIn checks if the user has completed sign-in.
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
// loadAvailableModels fetches and caches the list of available models.
func (m *model) loadAvailableModels() {
m.availableModels = make(map[string]bool)
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
models, err := client.List(context.Background())
if err != nil {
return
}
for _, mdl := range models.Models {
m.availableModels[mdl.Name] = true
}
}
func (m *model) buildItems() {
others := getOtherIntegrations()
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
m.items = append(m.items, mainMenuItems...)
if m.showOthers {
// Change "Others..." to "Hide others..."
hideItem := menuItem{
title: "Hide others...",
description: "Hide additional integrations",
isOthers: true,
}
m.items = append(m.items, hideItem)
m.items = append(m.items, others...)
} else {
m.items = append(m.items, othersMenuItem)
}
}
// isOthersIntegration returns true if the integration is in the "Others" menu
func isOthersIntegration(name string) bool {
switch name {
case "droid", "opencode":
return true
}
return false
}
func initialModel() model {
m := model{
cursor: 0,
}
m.loadAvailableModels()
// Check last selection to determine if we need to expand "Others"
lastSelection := config.LastSelection()
if isOthersIntegration(lastSelection) {
m.showOthers = true
}
m.buildItems()
// Position cursor on last selection
if lastSelection != "" {
for i, item := range m.items {
if lastSelection == "run" && item.isRunModel {
m.cursor = i
break
} else if item.integration == lastSelection {
m.cursor = i
break
}
}
}
return m
}
func (m model) Init() tea.Cmd {
return nil
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Handle sign-in dialog
if m.showingSignIn {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
// Cancel sign-in and go back
m.showingSignIn = false
if m.signInFromModal {
m.showingModal = true
}
// If from main menu, just return to main menu (default state)
return m, nil
}
case signInTickMsg:
m.signInSpinner++
// Check sign-in status every 5th tick (~1 second)
if m.signInSpinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
// Sign-in complete - proceed with selection
if m.signInFromModal {
// Came from modal - set changeModel
m.modalSelector.selected = m.signInModel
m.changeModel = true
} else {
// Came from main menu - just select
m.selected = true
}
m.quitting = true
return m, tea.Quit
}
}
return m, nil
}
// Handle modal input if modal is showing
if m.showingModal {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
// Close modal without selection
m.showingModal = false
return m, nil
case tea.KeyEnter:
filtered := m.modalSelector.filteredItems()
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
}
if m.modalSelector.selected != "" {
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
return m, cmd
}
// Selection made - exit with changeModel
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
return m, nil
case tea.KeyUp:
if m.modalSelector.cursor > 0 {
m.modalSelector.cursor--
if m.modalSelector.cursor < m.modalSelector.scrollOffset {
m.modalSelector.scrollOffset = m.modalSelector.cursor
}
}
case tea.KeyDown:
filtered := m.modalSelector.filteredItems()
if m.modalSelector.cursor < len(filtered)-1 {
m.modalSelector.cursor++
if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
}
}
case tea.KeyPgUp:
filtered := m.modalSelector.filteredItems()
m.modalSelector.cursor -= maxSelectorItems
if m.modalSelector.cursor < 0 {
m.modalSelector.cursor = 0
}
m.modalSelector.scrollOffset -= maxSelectorItems
if m.modalSelector.scrollOffset < 0 {
m.modalSelector.scrollOffset = 0
}
_ = filtered // suppress unused warning
case tea.KeyPgDown:
filtered := m.modalSelector.filteredItems()
m.modalSelector.cursor += maxSelectorItems
if m.modalSelector.cursor >= len(filtered) {
m.modalSelector.cursor = len(filtered) - 1
}
if m.modalSelector.cursor >= m.modalSelector.scrollOffset+maxSelectorItems {
m.modalSelector.scrollOffset = m.modalSelector.cursor - maxSelectorItems + 1
}
case tea.KeyBackspace:
if len(m.modalSelector.filter) > 0 {
m.modalSelector.filter = m.modalSelector.filter[:len(m.modalSelector.filter)-1]
m.modalSelector.cursor = 0
m.modalSelector.scrollOffset = 0
}
case tea.KeyRunes:
m.modalSelector.filter += string(msg.Runes)
m.modalSelector.cursor = 0
m.modalSelector.scrollOffset = 0
}
}
return m, nil
}
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q", "esc":
m.quitting = true
return m, tea.Quit
case "up", "k":
if m.cursor > 0 {
m.cursor--
}
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
case "enter", " ":
item := m.items[m.cursor]
// Handle "Others..." toggle
if item.isOthers {
m.showOthers = !m.showOthers
m.buildItems()
// Keep cursor on the Others/Hide item
if m.cursor >= len(m.items) {
m.cursor = len(m.items) - 1
}
return m, nil
}
// Don't allow selecting uninstalled integrations
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
// Check if a cloud model is configured and needs sign-in
var configuredModel string
if item.isRunModel {
configuredModel = config.LastModel()
} else if item.integration != "" {
configuredModel = config.IntegrationModel(item.integration)
}
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
return m, cmd
}
m.selected = true
m.quitting = true
return m, tea.Quit
case "right", "l":
// Allow model change for integrations and run model
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
// Don't allow for uninstalled integrations
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
m.openModelModal()
}
}
}
return m, nil
}
func (m model) View() string {
if m.quitting {
return ""
}
// Render sign-in dialog if showing
if m.showingSignIn {
return m.renderSignInDialog()
}
// Render modal overlay if showing - replaces main view
if m.showingModal {
return m.renderModal()
}
s := titleStyle.Render(" Ollama "+versionStyle.Render("v"+version.Version)) + "\n\n"
for i, item := range m.items {
cursor := " "
style := itemStyle
isInstalled := true
if item.integration != "" {
isInstalled = config.IsIntegrationInstalled(item.integration)
}
if m.cursor == i {
cursor = "▸ "
if isInstalled {
style = selectedStyle
} else {
style = greyedSelectedStyle
}
} else if !isInstalled && item.integration != "" {
style = greyedStyle
}
title := item.title
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
} else if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
title += " " + modelStyle.Render("("+mdl+")")
}
} else if item.isRunModel {
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
title += " " + modelStyle.Render("("+mdl+")")
}
}
s += style.Render(cursor+title) + "\n"
s += descStyle.Render(item.description) + "\n\n"
}
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render("↑/↓ navigate • enter select • → change model • esc quit")
return s
}
// renderModal renders the model picker modal.
func (m model) renderModal() string {
modalStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("245")).
Padding(1, 2).
MarginLeft(2)
var content strings.Builder
// Title with filter
content.WriteString(selectorTitleStyle.Render(m.modalSelector.title))
content.WriteString(" ")
if m.modalSelector.filter == "" {
content.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
content.WriteString(selectorInputStyle.Render(m.modalSelector.filter))
}
content.WriteString("\n\n")
filtered := m.modalSelector.filteredItems()
if len(filtered) == 0 {
content.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
content.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.modalSelector.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
if idx == m.modalSelector.cursor {
content.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
content.WriteString(selectorItemStyle.Render(item.Name))
}
if item.Description != "" {
content.WriteString(" ")
content.WriteString(selectorDescStyle.Render("- " + item.Description))
}
content.WriteString("\n")
}
if remaining := len(filtered) - m.modalSelector.scrollOffset - displayCount; remaining > 0 {
content.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
content.WriteString("\n")
}
}
content.WriteString("\n")
content.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
return modalStyle.Render(content.String())
}
// renderSignInDialog renders the sign-in dialog.
func (m model) renderSignInDialog() string {
dialogStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("245")).
Padding(1, 2).
MarginLeft(2)
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
spinner := spinnerFrames[m.signInSpinner%len(spinnerFrames)]
var content strings.Builder
content.WriteString(selectorTitleStyle.Render("Sign in required"))
content.WriteString("\n\n")
content.WriteString(fmt.Sprintf("To use %s, please sign in.\n\n", selectedStyle.Render(m.signInModel)))
content.WriteString("Navigate to:\n")
content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Render(" " + m.signInURL))
content.WriteString("\n\n")
content.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render(
fmt.Sprintf("%s Waiting for sign in to complete...", spinner)))
content.WriteString("\n\n")
content.WriteString(selectorHelpStyle.Render("esc cancel"))
return dialogStyle.Render(content.String())
}
// Selection represents what the user selected
type Selection int
const (
SelectionNone Selection = iota
SelectionRunModel
SelectionChangeRunModel
SelectionIntegration // Generic integration selection
SelectionChangeIntegration // Generic change model for integration
)
// Result contains the selection and any associated data
type Result struct {
Selection Selection
Integration string // integration name if applicable
Model string // model name if selected from modal
}
// Run starts the TUI and returns the user's selection
func Run() (Result, error) {
m := initialModel()
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
}
fm := finalModel.(model)
if fm.err != nil {
return Result{Selection: SelectionNone}, fm.err
}
// User quit without selecting
if !fm.selected && !fm.changeModel {
return Result{Selection: SelectionNone}, nil
}
item := fm.items[fm.cursor]
// Handle model change request
if fm.changeModel {
if item.isRunModel {
return Result{
Selection: SelectionChangeRunModel,
Model: fm.modalSelector.selected,
}, nil
}
return Result{
Selection: SelectionChangeIntegration,
Integration: item.integration,
Model: fm.modalSelector.selected,
}, nil
}
// Handle selection
if item.isRunModel {
return Result{Selection: SelectionRunModel}, nil
}
return Result{
Selection: SelectionIntegration,
Integration: item.integration,
}, nil
}

View File

@@ -105,21 +105,52 @@
{
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/cline",
"/integrations/openclaw",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",
"/integrations/jetbrains",
"/integrations/marimo",
"/integrations/n8n",
"/integrations/onyx",
"/integrations/opencode",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
"/integrations/index",
{
"group": "Coding",
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose"
]
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
]
},
{
"group": "IDEs & Editors",
"pages": [
"/integrations/cline",
"/integrations/jetbrains",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
]
},
{
"group": "Chat & RAG",
"pages": [
"/integrations/onyx"
]
},
{
"group": "Automation",
"pages": [
"/integrations/n8n"
]
},
{
"group": "Notebooks",
"pages": [
"/integrations/marimo"
]
}
]
},
{

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting.mdx) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu).
Please refer to the [GPU docs](./gpu.mdx).
## How can I specify the context window size?
@@ -66,7 +66,7 @@ llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
</Info>
The `Processor` column will show which memory the model was loaded in to:
The `Processor` column will show which memory the model was loaded into:
- `100% GPU` means the model was loaded entirely into the GPU
- `100% CPU` means the model was loaded entirely in system memory
@@ -158,7 +158,7 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
## Does Ollama send my prompts and answers back to ollama.com?
No. Ollama runs locally, and conversation data does not leave your machine.
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
## How can I expose Ollama on my network?
@@ -183,7 +183,7 @@ server {
## How can I use Ollama with ngrok?
Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok:
Ollama can be accessed using a range of tunneling apps. For example with Ngrok:
```shell
ngrok http 11434 --host-header="localhost:11434"
@@ -240,7 +240,7 @@ GPU acceleration is not available for Docker Desktop in macOS due to the lack of
This can impact both installing Ollama, as well as downloading models.
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernet (WSL)` adapter, right click and select `Properties`.
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. _Disable_ both of these
properties.
@@ -299,7 +299,7 @@ The `keep_alive` API parameter with the `/api/generate` and `/api/chat` API endp
## How do I manage the maximum number of requests the Ollama server can queue?
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
## How does Ollama handle concurrent requests?
@@ -312,10 +312,10 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPU's VRAM.
## How does Ollama load models on multiple GPUs?
@@ -382,7 +382,7 @@ ollama signin
Replace &lt;username&gt; with your actual Windows user name.
</Note>
## How can I stop Ollama from starting when I login to my computer
## How can I stop Ollama from starting when I login to my computer?
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
@@ -390,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under `Allow in the Background`, then click the slider to disable.

View File

@@ -0,0 +1,50 @@
---
title: Overview
---
Ollama integrates with a wide range of tools.
## Coding Agents
Coding assistants that can read, modify, and execute code in your projects.
- [Claude Code](/integrations/claude-code)
- [Codex](/integrations/codex)
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
## Assistants
AI assistants that help with everyday tasks.
- [OpenClaw](/integrations/openclaw)
## IDEs & Editors
Native integrations for popular development environments.
- [VS Code](/integrations/vscode)
- [Cline](/integrations/cline)
- [Roo Code](/integrations/roo-code)
- [JetBrains](/integrations/jetbrains)
- [Xcode](/integrations/xcode)
- [Zed](/integrations/zed)
## Chat & RAG
Chat interfaces and retrieval-augmented generation platforms.
- [Onyx](/integrations/onyx)
## Automation
Workflow automation platforms with AI integration.
- [n8n](/integrations/n8n)
## Notebooks
Interactive computing environments with AI capabilities.
- [marimo](/integrations/marimo)

View File

@@ -216,6 +216,7 @@ func String(s string) func() string {
var (
LLMLibrary = String("OLLAMA_LLM_LIBRARY")
Editor = String("OLLAMA_EDITOR")
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
@@ -291,6 +292,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},

23
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -21,14 +21,18 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/mattn/go-runewidth v0.0.14
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -38,22 +42,35 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tkrajina/go-reflector v0.5.5 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

67
go.sum
View File

@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -148,13 +164,19 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -162,6 +184,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -182,8 +210,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -206,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -220,6 +276,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -306,6 +364,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -5,6 +5,7 @@ import (
)
var ErrInterrupt = errors.New("Interrupt")
var ErrEditPrompt = errors.New("EditPrompt")
type InterruptError struct {
Line []rune

View File

@@ -41,6 +41,7 @@ type Instance struct {
Terminal *Terminal
History *History
Pasting bool
Prefill string
pastedLines []string
}
@@ -89,6 +90,26 @@ func (i *Instance) Readline() (string, error) {
buf, _ := NewBuffer(i.Prompt)
if i.Prefill != "" {
lines := strings.Split(i.Prefill, "\n")
i.Prefill = ""
for idx, l := range lines {
for _, r := range l {
buf.Add(r)
}
if idx < len(lines)-1 {
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
}
}
}
var esc bool
var escex bool
var metaDel bool
@@ -251,6 +272,29 @@ func (i *Instance) Readline() (string, error) {
buf.ClearScreen()
case CharCtrlW:
buf.DeleteWord()
case CharBell:
output := buf.String()
numPastedLines := len(i.pastedLines)
if numPastedLines > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
// Move cursor to the last display line of the current buffer
currLine := buf.DisplayPos / buf.LineWidth
lastLine := buf.DisplaySize() / buf.LineWidth
if lastLine > currLine {
fmt.Print(CursorDownN(lastLine - currLine))
}
// Clear all lines from bottom to top: buffer wrapped lines + pasted lines
for range lastLine + numPastedLines {
fmt.Print(CursorBOL + ClearToEOL + CursorUp)
}
fmt.Print(CursorBOL + ClearToEOL)
i.Prompt.UseAlt = false
return output, ErrEditPrompt
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)

View File

@@ -4,6 +4,7 @@ import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -17,6 +18,8 @@ func Execute(args []string) error {
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)

View File

@@ -302,12 +302,22 @@ function deps {
}
function sign {
# Copy install.ps1 to dist for release packaging
write-host "Copying install.ps1 to dist"
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
if ("${env:KEY_CONTAINER}") {
write-host "Signing Ollama executables, scripts and libraries"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
write-host "Signing install.ps1"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
"${script:SRC_DIR}\dist\install.ps1"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "Signing not enabled"
}

271
scripts/install.ps1 Normal file
View File

@@ -0,0 +1,271 @@
<#
.SYNOPSIS
Install, upgrade, or uninstall Ollama on Windows.
.DESCRIPTION
Downloads and installs Ollama.
Quick install:
irm https://ollama.com/install.ps1 | iex
Specific version:
$env:OLLAMA_VERSION="0.5.7"; irm https://ollama.com/install.ps1 | iex
Custom install directory:
$env:OLLAMA_INSTALL_DIR="D:\Ollama"; irm https://ollama.com/install.ps1 | iex
Uninstall:
$env:OLLAMA_UNINSTALL=1; irm https://ollama.com/install.ps1 | iex
Environment variables:
OLLAMA_VERSION Target version (default: latest stable)
OLLAMA_INSTALL_DIR Custom install directory
OLLAMA_UNINSTALL Set to 1 to uninstall Ollama
OLLAMA_DEBUG Enable verbose output
.EXAMPLE
irm https://ollama.com/install.ps1 | iex
.EXAMPLE
$env:OLLAMA_VERSION = "0.5.7"; irm https://ollama.com/install.ps1 | iex
.LINK
https://ollama.com
#>
$ErrorActionPreference = "Stop"
$ProgressPreference = "SilentlyContinue"
# --------------------------------------------------------------------------
# Configuration from environment variables
# --------------------------------------------------------------------------
$Version = if ($env:OLLAMA_VERSION) { $env:OLLAMA_VERSION } else { "" }
$InstallDir = if ($env:OLLAMA_INSTALL_DIR) { $env:OLLAMA_INSTALL_DIR } else { "" }
$Uninstall = $env:OLLAMA_UNINSTALL -eq "1"
$DebugInstall = [bool]$env:OLLAMA_DEBUG
# --------------------------------------------------------------------------
# Constants
# --------------------------------------------------------------------------
# OLLAMA_DOWNLOAD_URL for developer testing only
$DownloadBaseURL = if ($env:OLLAMA_DOWNLOAD_URL) { $env:OLLAMA_DOWNLOAD_URL.TrimEnd('/') } else { "https://ollama.com/download" }
$InnoSetupUninstallGuid = "{44E83376-CE68-45EB-8FC1-393500EB558C}_is1"
# --------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------
function Write-Status {
param([string]$Message)
if ($DebugInstall) { Write-Host $Message }
}
function Write-Step {
param([string]$Message)
if ($DebugInstall) { Write-Host ">>> $Message" -ForegroundColor Cyan }
}
function Test-Signature {
param([string]$FilePath)
$sig = Get-AuthenticodeSignature -FilePath $FilePath
if ($sig.Status -ne "Valid") {
Write-Status " Signature status: $($sig.Status)"
return $false
}
# Verify it's signed by Ollama Inc. (check exact organization name)
# Anchor with comma/boundary to prevent "O=Not Ollama Inc." from matching
$subject = $sig.SignerCertificate.Subject
if ($subject -notmatch "(^|, )O=Ollama Inc\.(,|$)") {
Write-Status " Unexpected signer: $subject"
return $false
}
Write-Status " Signature valid: $subject"
return $true
}
function Find-InnoSetupInstall {
# Check both HKCU (per-user) and HKLM (per-machine) locations
$possibleKeys = @(
"HKCU:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
"HKLM:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
"HKLM:\Software\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid"
)
foreach ($key in $possibleKeys) {
if (Test-Path $key) {
Write-Status " Found install at: $key"
return $key
}
}
return $null
}
function Update-SessionPath {
# Update PATH in current session so 'ollama' works immediately
if ($InstallDir) {
$ollamaDir = $InstallDir
} else {
$ollamaDir = Join-Path $env:LOCALAPPDATA "Programs\Ollama"
}
# Add to PATH if not already present
if (Test-Path $ollamaDir) {
$currentPath = $env:PATH -split ';'
if ($ollamaDir -notin $currentPath) {
$env:PATH = "$ollamaDir;$env:PATH"
Write-Status " Added $ollamaDir to session PATH"
}
}
}
function Invoke-Download {
param(
[string]$Url,
[string]$OutFile
)
Write-Status " Downloading: $Url"
try {
Invoke-WebRequest -Uri $Url -OutFile $OutFile -UseBasicParsing
$size = (Get-Item $OutFile).Length
Write-Status " Downloaded: $([math]::Round($size / 1MB, 1)) MB"
} catch {
if ($_.Exception.Response.StatusCode -eq 404) {
throw "Download failed: not found at $Url"
}
throw "Download failed for ${Url}: $($_.Exception.Message)"
}
}
# --------------------------------------------------------------------------
# Uninstall
# --------------------------------------------------------------------------
function Invoke-Uninstall {
Write-Step "Uninstalling Ollama"
$regKey = Find-InnoSetupInstall
if (-not $regKey) {
Write-Host "Ollama is not installed."
return
}
$uninstallString = (Get-ItemProperty -Path $regKey).UninstallString
if (-not $uninstallString) {
Write-Warning "No uninstall string found in registry"
return
}
# Strip quotes if present
$uninstallExe = $uninstallString -replace '"', ''
Write-Status " Uninstaller: $uninstallExe"
if (-not (Test-Path $uninstallExe)) {
Write-Warning "Uninstaller not found at: $uninstallExe"
return
}
Write-Host "Launching uninstaller..."
# Run with GUI so user can choose whether to keep models
Start-Process -FilePath $uninstallExe -Wait
# Verify removal
if (Find-InnoSetupInstall) {
Write-Warning "Uninstall may not have completed"
} else {
Write-Host "Ollama has been uninstalled."
}
}
# --------------------------------------------------------------------------
# Install
# --------------------------------------------------------------------------
function Invoke-Install {
# Determine installer URL
if ($Version) {
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe?version=$Version"
} else {
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe"
}
# Download installer
Write-Step "Downloading Ollama"
if (-not $DebugInstall) {
Write-Host "Downloading Ollama..."
}
$tempInstaller = Join-Path $env:TEMP "OllamaSetup.exe"
Invoke-Download -Url $installerUrl -OutFile $tempInstaller
# Verify signature
Write-Step "Verifying signature"
if (-not (Test-Signature -FilePath $tempInstaller)) {
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
throw "Installer signature verification failed"
}
# Build installer arguments
$installerArgs = "/VERYSILENT /NORESTART /SUPPRESSMSGBOXES"
if ($InstallDir) {
$installerArgs += " /DIR=`"$InstallDir`""
}
Write-Status " Installer args: $installerArgs"
# Run installer
Write-Step "Installing Ollama"
if (-not $DebugInstall) {
Write-Host "Installing..."
}
# Create upgrade marker so the app starts hidden
# The app checks for this file on startup and removes it after
$markerDir = Join-Path $env:LOCALAPPDATA "Ollama"
$markerFile = Join-Path $markerDir "upgraded"
if (-not (Test-Path $markerDir)) {
New-Item -ItemType Directory -Path $markerDir -Force | Out-Null
}
New-Item -ItemType File -Path $markerFile -Force | Out-Null
Write-Status " Created upgrade marker: $markerFile"
# Start installer and wait for just the installer process (not children)
# Using -Wait would wait for Ollama to exit too, which we don't want
$proc = Start-Process -FilePath $tempInstaller `
-ArgumentList $installerArgs `
-PassThru
$proc.WaitForExit()
if ($proc.ExitCode -ne 0) {
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
throw "Installation failed with exit code $($proc.ExitCode)"
}
# Cleanup
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
# Update PATH in current session so 'ollama' works immediately
Write-Step "Updating session PATH"
Update-SessionPath
Write-Host "Install complete. You can now run 'ollama'."
}
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
if ($Uninstall) {
Invoke-Uninstall
} else {
Invoke-Install
}

View File

@@ -35,6 +35,8 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
###########################################
# macOS
###########################################
@@ -49,11 +51,7 @@ if [ "$OS" = "Darwin" ]; then
exit 1
fi
if [ -n "${OLLAMA_VERSION:-}" ]; then
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
else
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
fi
DOWNLOAD_URL="https://ollama.com/download/Ollama-darwin.zip${VER_PARAM}"
if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
@@ -103,8 +101,6 @@ case "$KERN" in
*) ;;
esac
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
SUDO=
if [ "$(id -u)" -ne 0 ]; then
# Running as root, no need for sudo

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
@@ -35,7 +36,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
@@ -94,6 +95,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
newPackedTensorLayerCreator(),
)
} else {
err = create.CreateImageGenModel(
@@ -141,60 +143,33 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
// createQuantizedLayers quantizes a tensor and returns a single combined layer.
// The combined blob contains data, scale, and optional bias tensors with metadata.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
// Quantize the tensor into a single combined blob
blobData, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
// Create single layer for the combined blob
layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
return []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
@@ -214,6 +189,58 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}, nil
}
// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for
// creating packed multi-tensor blob layers (used for expert groups).
func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) {
// Check if any tensor in the group needs quantization
hasQuantize := false
for _, t := range tensors {
if t.Quantize != "" {
hasQuantize = true
break
}
}
var blobReader io.Reader
if hasQuantize {
if !QuantizeSupported() {
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
}
blobData, err := quantizePackedGroup(tensors)
if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
}
blobReader = bytes.NewReader(blobData)
} else {
// Build unquantized packed blob using streaming reader
// Extract raw tensor data from safetensors-wrapped readers
var tds []*safetensors.TensorData
for _, t := range tensors {
rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader)
if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err)
}
td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData)
tds = append(tds, td)
}
blobReader = safetensors.BuildPackedSafetensorsReader(tds)
}
layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: groupName,
}, nil
}
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {

View File

@@ -3,128 +3,195 @@
package client
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types:
// - "q4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "q8": affine 8-bit, group_size=64 (with qbiases)
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
// quantizeParams maps quantization type names to MLX quantize parameters.
var quantizeParams = map[string]struct {
groupSize int
bits int
mode string
}{
"int4": {32, 4, "affine"},
"nvfp4": {16, 4, "nvfp4"},
"int8": {64, 8, "affine"},
"mxfp8": {32, 8, "mxfp8"},
}
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided maps. If quantize is empty, the tensor is kept as-is.
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
tmpPath = tmpFile.Name()
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
// Find the tensor key (may differ from name for single-tensor blobs)
inputKey, err := findSafetensorsKey(tmpPath)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
}
arr := st.Get(inputKey)
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
st.Free()
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
}
// Convert to BFloat16 if needed (quantize expects float type)
if quantize == "" {
arr = mlx.Contiguous(arr)
arrays[name] = arr
return tmpPath, []*mlx.Array{arr}, st, nil
}
// Convert to float type if needed (quantize expects float)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "q4":
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "q8":
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
case "mxfp8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
params, ok := quantizeParams[quantize]
if !ok {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
// Eval and make contiguous for data access
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
arrays[name] = qweight
arrays[name+".scale"] = scales
toEval = append(toEval, qweight, scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
arrays[name+".bias"] = qbiases
toEval = append(toEval, qbiases)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
return tmpPath, toEval, st, nil
}
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
arrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
if tmpPath != "" {
defer os.Remove(tmpPath)
}
if st != nil {
defer st.Free()
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
return nil, err
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
mlx.Eval(toEval...)
// Build metadata for single-tensor blobs
params := quantizeParams[quantize]
metadata := map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(params.groupSize),
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save combined blob: %w", err)
}
return os.ReadFile(outPath)
}
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups.
// Each tensor may have a different quantization type (mixed-precision).
// Returns the blob bytes. No __metadata__ is added because different tensors
// may use different quantization types.
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
allArrays := make(map[string]*mlx.Array)
var allToEval []*mlx.Array
var tmpPaths []string
var handles []*mlx.SafetensorsFile
for _, input := range inputs {
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
if tmpPath != "" {
tmpPaths = append(tmpPaths, tmpPath)
}
if st != nil {
handles = append(handles, st)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
// Cleanup on error
for _, h := range handles {
h.Free()
}
for _, p := range tmpPaths {
os.Remove(p)
}
return nil, err
}
allToEval = append(allToEval, toEval...)
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
mlx.Eval(allToEval...)
// Free native handles after eval
for _, h := range handles {
h.Free()
}
// Save combined blob (no global metadata for mixed-precision packed blobs)
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
return nil, fmt.Errorf("failed to save packed blob: %w", err)
}
blobData, err := os.ReadFile(outPath)
if err != nil {
return nil, fmt.Errorf("failed to read packed blob: %w", err)
}
for _, p := range tmpPaths {
os.Remove(p)
}
return blobData, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
@@ -138,3 +205,33 @@ func ensureTempDir() string {
os.MkdirAll(tmpDir, 0755)
return tmpDir
}
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
func findSafetensorsKey(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return "", err
}
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(f, headerBytes); err != nil {
return "", err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return "", err
}
for k := range header {
if k != "__metadata__" {
return k, nil
}
}
return "", fmt.Errorf("no tensor found in safetensors header")
}

View File

@@ -5,11 +5,18 @@ package client
import (
"fmt"
"io"
"github.com/ollama/ollama/x/create"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// quantizePackedGroup is not available without MLX
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available

View File

@@ -6,7 +6,9 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strings"
"github.com/ollama/ollama/envconfig"
@@ -228,7 +230,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
// When quantize is non-empty (e.g., "int8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -264,19 +266,19 @@ func ShouldQuantize(name, component string) bool {
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
return "q4"
return "int4"
case "Q8", "INT8", "FP8":
return "q8"
return "int8"
case "NVFP4":
return "nvfp4"
case "MXFP8":
@@ -286,29 +288,12 @@ func normalizeQuantType(quantize string) string {
}
}
// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
switch normalizeQuantType(quantize) {
case "nvfp4":
return 16
case "q4":
return 32
case "mxfp8":
return 32
case "q8":
return 64
default:
return 32
}
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: q4 (less sensitive)
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Output projection, gate/up weights: int4 (less sensitive)
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
@@ -330,12 +315,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, q4/mxfp8: 32, q8: 64
// nvfp4: 16, int4/mxfp8: 32, int8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "q8":
case "int8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
@@ -363,13 +348,13 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return "" // No quantization - keep bf16
}
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "q8"
return "int8"
}
// Output projection, gate/up weights - use requested quantization (Q4)
// Output projection, gate/up weights - use requested quantization (INT4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
@@ -386,14 +371,69 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return quantNorm
}
// expertGroupRegexp matches expert tensor names and captures the group prefix.
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
func ExpertGroupPrefix(tensorName string) string {
m := expertGroupRegexp.FindStringSubmatch(tensorName)
if m == nil {
return ""
}
return m[1]
}
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
type PackedTensorInput struct {
Name string
Dtype string
Shape []int32
Quantize string // per-tensor quantization type (may differ within group)
Reader io.Reader // safetensors-wrapped tensor data
}
// PackedTensorLayerCreator creates a single blob layer containing multiple packed tensors.
// groupName is the group prefix (e.g., "model.layers.1.mlp.experts").
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Expert tensors are packed into per-layer blobs when createPackedLayer is non-nil.
// If quantize is non-empty (e.g., "int8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error {
var layers []LayerInfo
var configLayer LayerInfo
// Resolve the optional packed layer creator
var packedCreator PackedTensorLayerCreator
if len(createPackedLayer) > 0 {
packedCreator = createPackedLayer[0]
}
// Accumulate expert tensors by group prefix for packing.
// Readers reference file-backed SectionReaders, so we keep extractors
// open until each group is flushed to avoid buffering tensor data in memory.
expertGroups := make(map[string][]PackedTensorInput)
var expertGroupOrder []string
// Track open extractors so we can close them after flushing groups
var openExtractors []*safetensors.TensorExtractor
closeExtractors := func() {
for _, ext := range openExtractors {
ext.Close()
}
openExtractors = nil
}
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
@@ -410,6 +450,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
closeExtractors()
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
@@ -420,10 +461,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
// Track whether this extractor has expert tensors that need to stay open
hasExpertTensors := false
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
@@ -434,20 +479,65 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
// Check if this tensor belongs to an expert group for packing
groupPrefix := ""
if packedCreator != nil {
groupPrefix = ExpertGroupPrefix(tensorName)
}
if groupPrefix != "" {
// Accumulate expert tensor for packed blob.
// The Reader uses a file-backed SectionReader, so we must
// keep the extractor open until this group is flushed.
hasExpertTensors = true
if _, exists := expertGroups[groupPrefix]; !exists {
expertGroupOrder = append(expertGroupOrder, groupPrefix)
}
expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
Name: tensorName,
Dtype: td.Dtype,
Shape: td.Shape,
Quantize: quantizeType,
Reader: td.SafetensorsReader(),
})
} else {
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
layers = append(layers, newLayers...)
}
extractor.Close()
if hasExpertTensors {
// Keep extractor open - readers still reference its file handle
openExtractors = append(openExtractors, extractor)
} else {
extractor.Close()
}
}
// Process accumulated expert groups into packed blobs, then close extractors
if packedCreator != nil {
sort.Strings(expertGroupOrder)
for _, groupName := range expertGroupOrder {
tensors := expertGroups[groupName]
fn(fmt.Sprintf("packing %s (%d tensors)", groupName, len(tensors)))
layer, err := packedCreator(groupName, tensors)
if err != nil {
closeExtractors()
return fmt.Errorf("failed to create packed layer for %s: %w", groupName, err)
}
layers = append(layers, layer)
}
}
closeExtractors()
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
@@ -487,23 +577,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
// Create model_index.json with quantization info if quantizing
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
"group_size": getQuantGroupSize(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal model_index.json: %w", err)
}
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
if err != nil {
return fmt.Errorf("failed to create model_index.json layer: %w", err)
}
layers = append(layers, indexLayer)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {

View File

@@ -586,6 +586,39 @@ func TestShouldQuantizeTensor(t *testing.T) {
}
}
func TestExpertGroupPrefix(t *testing.T) {
tests := []struct {
name string
want string
}{
// Expert tensors should return the group prefix
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
// Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
{"model.embed_tokens.weight", ""}, // embedding
{"model.layers.0.self_attn.q_proj.weight", ""}, // attention
{"model.norm.weight", ""}, // norm
{"lm_head.weight", ""}, // output head
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExpertGroupPrefix(tt.name)
if got != tt.want {
t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
@@ -751,7 +784,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "q4", "q8", "nvfp4", "mxfp8":
case "", "int4", "int8", "nvfp4", "mxfp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -214,7 +214,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size.
// nvfp4: 16, q4/mxfp8: 32, q8: 64
// nvfp4: 16, int4/mxfp8: 32, int8: 64
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
@@ -223,7 +223,7 @@ func canQuantizeShape(shape []int32, quantize string) bool {
switch strings.ToUpper(quantize) {
case "NVFP4":
groupSize = 16
case "Q8":
case "INT8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0

View File

@@ -0,0 +1,158 @@
# Tensor Blob Format
Ollama stores model tensors as individual blobs in the safetensors format. Each blob contains a logical tensor (or a combined quantized tensor with its scale/bias components), or a group of logical tensors (e.g. shared experts for a given layer along with the scale/bias components for that tensor).
## Safetensors File Format
Every blob follows the [safetensors](https://github.com/huggingface/safetensors) layout:
```
[8 bytes: header_size (uint64 LE)] [header_size bytes: JSON header] [tensor data region]
```
The JSON header maps tensor names to their dtype, shape, and byte offsets within the data region. A special `__metadata__` key holds string-to-string metadata.
## Unquantized Blobs
An unquantized blob stores a single tensor keyed by its name:
```json
{
"model.layers.0.self_attn.q_proj.weight": {
"dtype": "BF16",
"shape": [2560, 2560],
"data_offsets": [0, 13107200]
}
}
```
The tensor key is the full tensor name. Dtype is typically `BF16` or `F32`.
## Quantized Blobs (Combined Format)
A quantized blob stores the packed weight, scaling factors, and optional zero-point biases in a single file. Tensor keys use the tensor name, with `.scale` and `.bias` suffixes for the auxiliary tensors:
```json
{
"__metadata__": {
"quant_type": "int4",
"group_size": "32"
},
"model.layers.0.mlp.up_proj.weight": {
"dtype": "U32",
"shape": [2560, 320],
"data_offsets": [0, 3276800]
},
"model.layers.0.mlp.up_proj.weight.scale": {
"dtype": "BF16",
"shape": [2560, 80],
"data_offsets": [3276800, 3686400]
},
"model.layers.0.mlp.up_proj.weight.bias": {
"dtype": "BF16",
"shape": [2560, 80],
"data_offsets": [3686400, 4096000]
}
}
```
### Metadata Fields
| Field | Description |
|---|---|
| `quant_type` | Quantization type: `int4`, `int8`, `nvfp4`, or `mxfp8` |
| `group_size` | Number of elements per quantization group (e.g., `32`, `64`) |
### Tensor Keys
| Key | Description |
|---|---|
| `{name}` | Packed quantized weights (dtype `U32`) |
| `{name}.scale` | Per-group scaling factors |
| `{name}.bias` | Per-group zero-point offsets (affine modes only) |
## Quantization Types
| Type | Bits | Group Size | Mode | Has Bias |
|---|---|---|---|---|
| `int4` | 4 | 32 | affine | yes |
| `int8` | 8 | 64 | affine | yes |
| `nvfp4` | 4 | 16 | nvfp4 | no |
| `mxfp8` | 8 | 32 | mxfp8 | no |
**Affine modes** (`int4`, `int8`) use `scale + bias` for dequantization. The bias tensor provides the zero-point offset.
**Non-affine modes** (`nvfp4`, `mxfp8`) use only `scale` with specialized E4M3 scale formats.
### Packed Weight Shape
Quantized weights are packed into `uint32` values:
- **4-bit** (int4, nvfp4): 8 values per uint32, so `packed_cols = original_cols / 8`
- **8-bit** (int8, mxfp8): 4 values per uint32, so `packed_cols = original_cols / 4`
Scale shape: `[rows, original_cols / group_size]`
## Manifest References
Blobs are referenced from the model manifest as layers:
```json
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:abc123...",
"size": 4096150,
"name": "model.layers.0.mlp.up_proj.weight"
}
```
Each tensor (quantized or not) is one layer in the manifest. The layer name matches the tensor key in the blob header.
## Packed Blobs (Expert Groups)
For MoE (Mixture of Experts) models, expert tensors from the same layer are packed into a single blob to reduce blob count and improve loading efficiency. A packed blob is a standard safetensors file containing multiple tensor entries:
```json
{
"model.layers.1.mlp.experts.0.down_proj.weight": {
"dtype": "U32",
"shape": [2560, 640],
"data_offsets": [0, 6553600]
},
"model.layers.1.mlp.experts.0.down_proj.weight.scale": {
"dtype": "BF16",
"shape": [2560, 40],
"data_offsets": [6553600, 6963200]
},
"model.layers.1.mlp.experts.0.gate_proj.weight": {
"dtype": "U32",
"shape": [10240, 320],
"data_offsets": [6963200, 20070400]
},
"model.layers.1.mlp.experts.0.gate_proj.weight.scale": { "..." : "..." }
}
```
### Grouping Rules
- `model.layers.{L}.mlp.experts.*` tensors are packed into one blob per layer
- `model.layers.{L}.mlp.shared_experts.*` tensors are packed into one blob per layer
- All other tensors remain as individual blobs
### Manifest Representation
One manifest layer per packed group, using the group prefix as the layer name:
```json
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:...",
"size": 123456789,
"name": "model.layers.1.mlp.experts"
}
```
## Loading
At load time, `mlx_load_safetensors` opens each blob via mmap for zero-copy access. For combined quantized blobs, the loader extracts `{name}`, `{name}.scale`, and `{name}.bias` tensors and caches them as `name`, `name + "_scale"`, and `name + "_qbias"` respectively, maintaining compatibility with the weight loading interface.
For packed blobs, if the manifest layer name (group prefix) is not found as a tensor key, the loader parses the blob header to discover all tensor names and loads each individually.

View File

@@ -1,11 +1,13 @@
package manifest
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/envconfig"
@@ -205,17 +207,12 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
}
}
// Fallback: detect quantization from tensor names if not in config
// Fallback: detect quantization from first tensor blob's __metadata__
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "Q8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
info.Quantization = detectQuantizationFromBlobs(manifest)
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
// Fallback: estimate parameter count if not in config
@@ -223,9 +220,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
totalSize += layer.Size
}
}
// Assume BF16 (2 bytes/param) as rough estimate
@@ -234,3 +229,79 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
return info, nil
}
// detectQuantizationFromBlobs reads __metadata__ from the first tensor blob
// to detect quantization type.
func detectQuantizationFromBlobs(manifest *ModelManifest) string {
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
data, err := readBlobHeader(manifest.BlobPath(layer.Digest))
if err != nil {
continue
}
var header map[string]json.RawMessage
if json.Unmarshal(data, &header) != nil {
continue
}
if metaRaw, ok := header["__metadata__"]; ok {
var meta map[string]string
if json.Unmarshal(metaRaw, &meta) == nil {
if qt, ok := meta["quant_type"]; ok && qt != "" {
return strings.ToUpper(qt)
}
}
}
// Only check the first tensor blob
break
}
return ""
}
// ParseBlobTensorNames reads a safetensors blob and returns all "main" tensor names.
// Filters out __metadata__, .scale, and .bias entries to return only primary weight tensors.
func ParseBlobTensorNames(path string) ([]string, error) {
data, err := readBlobHeader(path)
if err != nil {
return nil, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
}
var names []string
for k := range header {
if k == "__metadata__" || strings.HasSuffix(k, ".scale") || strings.HasSuffix(k, ".bias") {
continue
}
names = append(names, k)
}
sort.Strings(names)
return names, nil
}
// readBlobHeader reads the JSON header bytes from a safetensors blob file.
func readBlobHeader(path string) ([]byte, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
}
return data, nil
}

View File

@@ -5,6 +5,7 @@ package manifest
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
@@ -18,6 +19,8 @@ type ManifestWeights struct {
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
quantType string // quantization type from blob metadata (e.g., "int4", "int8")
groupSize int // quantization group size from blob metadata
}
// LoadWeightsFromManifest creates a weight loader from manifest storage.
@@ -54,43 +57,129 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
// Load loads all tensor blobs using native mmap (zero-copy).
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
// Combined quantized blobs contain tensors keyed by name, name+".scale", and optional name+".bias"
// with quantization metadata. Scale and bias are stored in cache as name+"_scale"
// and name+"_qbias" for compatibility with downstream loading code.
// Packed blobs (e.g., for expert groups) contain multiple tensors; the manifest name
// is a group prefix and individual tensors are loaded by their actual names from the blob.
// If dtype is non-zero, non-quantized tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
// Track native handles to free after batch eval
nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
arrays := make([]*mlx.Array, 0, len(mw.tensors))
// Group tensors by digest to avoid loading the same blob multiple times
type blobEntry struct {
name string
layer ManifestLayer
}
blobGroups := make(map[string][]blobEntry)
for name, layer := range mw.tensors {
path := mw.manifest.BlobPath(layer.Digest)
blobGroups[layer.Digest] = append(blobGroups[layer.Digest], blobEntry{name, layer})
}
for digest, entries := range blobGroups {
path := mw.manifest.BlobPath(digest)
// Load blob as safetensors (native mmap, zero-copy)
sf, err := mlx.LoadSafetensorsNative(path)
if err != nil {
// Free any handles we've accumulated
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("load %s: %w", name, err)
return fmt.Errorf("load %s: %w", entries[0].name, err)
}
nativeHandles = append(nativeHandles, sf)
// Blob contains single tensor named "data"
arr := sf.Get("data")
if arr == nil {
for _, h := range nativeHandles {
h.Free()
// Read quantization metadata from blob
if qt := sf.GetMetadata("quant_type"); qt != "" && mw.quantType == "" {
mw.quantType = qt
if gs := sf.GetMetadata("group_size"); gs != "" {
mw.groupSize, _ = strconv.Atoi(gs)
}
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
}
// Convert dtype if needed
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
for _, entry := range entries {
name := entry.name
// Try to get tensor by stripped name first, then with component prefix.
// Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
// while the tensors map uses stripped names (e.g., "model.layers.0.weight").
lookupName := name
arr := sf.Get(lookupName)
if arr == nil && mw.component != "" {
lookupName = mw.component + "/" + name
arr = sf.Get(lookupName)
}
if arr != nil {
// Single-tensor blob or tensor found by name
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
}
arr = mlx.Contiguous(arr)
mw.cache[name] = arr
arrays = append(arrays, arr)
// Check for scale tensor
if scale := sf.Get(lookupName + ".scale"); scale != nil {
scale = mlx.Contiguous(scale)
mw.cache[name+"_scale"] = scale
arrays = append(arrays, scale)
}
// Check for bias tensor
if bias := sf.Get(lookupName + ".bias"); bias != nil {
bias = mlx.Contiguous(bias)
mw.cache[name+"_qbias"] = bias
arrays = append(arrays, bias)
}
} else {
// Packed blob: manifest name is a group prefix, not a tensor name.
// Load all individual tensors from the blob.
tensorNames, err := ParseBlobTensorNames(path)
if err != nil {
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("parse packed blob for %s: %w", name, err)
}
for _, tensorName := range tensorNames {
tArr := sf.Get(tensorName)
if tArr == nil {
continue
}
if dtype != 0 && tArr.Dtype() != dtype {
tArr = mlx.AsType(tArr, dtype)
}
tArr = mlx.Contiguous(tArr)
// Strip component prefix from blob-internal names so cache keys
// match the stripped names used by LoadModule.
cacheName := tensorName
if mw.component != "" {
cacheName = strings.TrimPrefix(tensorName, mw.component+"/")
}
mw.cache[cacheName] = tArr
arrays = append(arrays, tArr)
// Check for scale tensor
if scale := sf.Get(tensorName + ".scale"); scale != nil {
scale = mlx.Contiguous(scale)
mw.cache[cacheName+"_scale"] = scale
arrays = append(arrays, scale)
}
// Check for bias tensor
if bias := sf.Get(tensorName + ".bias"); bias != nil {
bias = mlx.Contiguous(bias)
mw.cache[cacheName+"_qbias"] = bias
arrays = append(arrays, bias)
}
}
}
}
// Make contiguous copy to ensure independence from mmap
arr = mlx.Contiguous(arr)
mw.cache[name] = arr
arrays = append(arrays, arr)
}
// Batch evaluate all tensors at once (much faster than one at a time)
@@ -117,30 +206,50 @@ func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
}
// ListTensors returns all tensor names in sorted order.
// Includes both manifest tensor names and scale/bias entries from combined blobs.
func (mw *ManifestWeights) ListTensors() []string {
names := make([]string, 0, len(mw.tensors))
seen := make(map[string]bool, len(mw.tensors)+len(mw.cache))
for name := range mw.tensors {
seen[name] = true
}
// Also include cache entries (scale/bias from combined blobs)
for name := range mw.cache {
seen[name] = true
}
names := make([]string, 0, len(seen))
for name := range seen {
names = append(names, name)
}
sort.Strings(names)
return names
}
// HasTensor checks if a tensor exists.
// HasTensor checks if a tensor exists in the manifest or cache.
func (mw *ManifestWeights) HasTensor(name string) bool {
_, ok := mw.tensors[name]
return ok
if _, ok := mw.tensors[name]; ok {
return true
}
// Also check cache for scale/bias entries from combined blobs
if _, ok := mw.cache[name]; ok {
return true
}
return false
}
// Quantization returns the model's quantization type from model_index.json.
// Quantization returns the model's quantization type.
// Returns the quant_type from blob metadata (e.g., "int4", "int8", "nvfp4", "mxfp8").
// Returns empty string if not quantized.
// Falls back to detecting from tensor names and shapes if not in config.
// Falls back to model_index.json for image gen models.
func (mw *ManifestWeights) Quantization() string {
if mw.quantType != "" {
return strings.ToUpper(mw.quantType)
}
if mw.manifest == nil {
return ""
}
// Try to read from model_index.json first
// Fallback: read from model_index.json (for image gen models)
var index struct {
Quantization string `json:"quantization"`
}
@@ -148,89 +257,22 @@ func (mw *ManifestWeights) Quantization() string {
return index.Quantization
}
// Fallback: detect from tensor names
// Check if any tensors have _scale suffix (indicates quantization)
hasScales := false
hasQBias := false
for name := range mw.tensors {
if strings.HasSuffix(name, ".weight_scale") {
hasScales = true
}
if strings.HasSuffix(name, ".weight_qbias") {
hasQBias = true
}
}
if !hasScales {
// No scales = not quantized
return ""
}
// Has scales but no qbias = NVFP4 (or other non-affine mode)
if !hasQBias {
return "NVFP4"
}
// Has both scales and qbias = affine mode
// Need to determine FP4 vs FP8 from tensor shapes
// FP4: weight last dim is 1/8 of scales last dim * group_size
// FP8: weight last dim is 1/4 of scales last dim * group_size
//
// For affine mode with group_size=32:
// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
// scales_dim = orig_dim / group_size
// So: weight_dim / scales_dim = group_size / pack_factor
// FP4: ratio = 32/8 = 4
// FP8: ratio = 32/4 = 8
// Find a weight/scale pair to check the ratio
for name := range mw.tensors {
if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
continue
}
scaleName := name + "_scale"
if _, ok := mw.tensors[scaleName]; !ok {
continue
}
// Load both tensors to check shapes
weightLayer := mw.tensors[name]
scaleLayer := mw.tensors[scaleName]
// Get shapes from manifest layer metadata if available
// For now, default to FP4 since it's more common
// The actual shape check would require loading the tensor
// Simple heuristic: check if scale tensor is ~4x smaller than weight
// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
// Rough size heuristic (assuming float16 scales)
// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
if ratio < 12 {
// Closer to 8 = Q4
return "Q4"
}
// Closer to 16 = Q8
return "Q8"
}
// Default to Q4 for affine mode (most common)
return "Q4"
return ""
}
// GroupSize returns the quantization group size from model_index.json.
// GroupSize returns the quantization group size.
// Returns the group_size from blob metadata.
// Returns 0 if not specified (caller should use default based on quantization type).
func (mw *ManifestWeights) GroupSize() int {
if mw.groupSize > 0 {
return mw.groupSize
}
if mw.manifest == nil {
return 0
}
// Fallback: read from model_index.json (for image gen models)
var index struct {
GroupSize int `json:"group_size"`
}

View File

@@ -1544,6 +1544,18 @@ func (s *SafetensorsFile) Count() int {
return 0
}
// GetMetadata retrieves a metadata value by key from the safetensors file
func (s *SafetensorsFile) GetMetadata(key string) string {
cKey := C.CString(key)
defer C.free(unsafe.Pointer(cKey))
var cValue *C.char
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
return ""
}
return C.GoString(cValue)
}
// Free releases the safetensors file
func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_array_free(s.arrays)
@@ -1578,6 +1590,41 @@ func SaveSafetensors(path string, arrays map[string]*Array) error {
return nil
}
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata key/value pairs.
// This is like SaveSafetensors but inserts metadata into the __metadata__ section.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the array map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create metadata map
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
for key, value := range metadata {
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMeta, cKey, cValue)
C.free(unsafe.Pointer(cKey))
C.free(unsafe.Pointer(cValue))
}
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file

View File

@@ -41,13 +41,11 @@ func (td *TensorData) Reader() io.Reader {
return td.reader
}
// SafetensorsReader returns a reader that outputs the tensor wrapped in
// minimal safetensors format. This allows using mlx_load_safetensors on
// individual tensor blobs for native zero-copy loading.
func (td *TensorData) SafetensorsReader() io.Reader {
// Build minimal safetensors header with tensor named "data"
header := map[string]tensorInfo{
"data": {
// safetensorsHeader builds the JSON header for a minimal safetensors blob
// containing a single tensor keyed by its name.
func (td *TensorData) safetensorsHeader() []byte {
header := map[string]any{
td.Name: tensorInfo{
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
@@ -58,6 +56,15 @@ func (td *TensorData) SafetensorsReader() io.Reader {
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
return headerJSON
}
// SafetensorsReader returns a reader that outputs the tensor wrapped in
// minimal safetensors format. This allows using mlx_load_safetensors on
// individual tensor blobs for native zero-copy loading.
// The tensor is keyed by its name in the safetensors header.
func (td *TensorData) SafetensorsReader() io.Reader {
headerJSON := td.safetensorsHeader()
// Build header with size prefix
headerBuf := new(bytes.Buffer)
@@ -71,16 +78,77 @@ func (td *TensorData) SafetensorsReader() io.Reader {
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
func (td *TensorData) SafetensorsSize() int64 {
header := map[string]tensorInfo{
"data": {
headerJSON := td.safetensorsHeader()
return 8 + int64(len(headerJSON)) + td.Size
}
// NewTensorDataFromBytes creates a TensorData from raw tensor bytes.
// This is useful for constructing packed blobs from already-extracted data.
func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *TensorData {
return &TensorData{
Name: name,
Dtype: dtype,
Shape: shape,
Size: int64(len(rawData)),
reader: io.NewSectionReader(bytes.NewReader(rawData), 0, int64(len(rawData))),
}
}
// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts
// the raw tensor data bytes (stripping the header).
func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
// Skip header
if _, err := io.CopyN(io.Discard, r, int64(headerSize)); err != nil {
return nil, fmt.Errorf("failed to skip header: %w", err)
}
// Read remaining bytes (the raw tensor data)
return io.ReadAll(r)
}
// BuildPackedSafetensorsReader builds a streaming io.Reader that outputs a valid
// safetensors file containing multiple tensors. Used for packing expert tensors
// into a single blob without loading all data into memory.
// Each TensorData must have been obtained from GetTensor.
func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
// Build the header with sequential data offsets
header := make(map[string]tensorInfo, len(tensors))
var offset int
for _, td := range tensors {
header[td.Name] = tensorInfo{
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
},
DataOffsets: [2]int{offset, offset + int(td.Size)},
}
offset += int(td.Size)
}
headerJSON, _ := json.Marshal(header)
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Build header with size prefix
headerBuf := new(bytes.Buffer)
binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
headerBuf.Write(headerJSON)
// Build multi-reader: header + all tensor data readers
readers := make([]io.Reader, 0, 1+len(tensors))
readers = append(readers, headerBuf)
for _, td := range tensors {
td.reader.Seek(0, io.SeekStart)
readers = append(readers, td.reader)
}
return io.MultiReader(readers...)
}
// OpenForExtraction opens a safetensors file for tensor extraction.

View File

@@ -17,7 +17,7 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
Quantization() string // Returns "NVFP4", "INT4", "INT8", or ""
GroupSize() int // Returns quantization group size, or 0 if not specified
}

96
x/mlxrunner/cache.go Normal file
View File

@@ -0,0 +1,96 @@
//go:build mlx
package mlxrunner
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
}
}

198
x/mlxrunner/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,198 @@
//go:build mlx
package cache
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Offset() int
Len() int
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if needed
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys != nil {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
}
c.offset += L
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset == c.keys.Dim(2) {
return c.keys, c.values
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
return n
}
func (c *KVCache) Clone() Cache {
return &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
maxSize int
idx int
*KVCache
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys, values
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
}
// Trim to max_size to maintain sliding window
if trim := c.idx - c.maxSize + 1; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, keys))
c.values.Set(c.values.Concatenate(2, values))
c.idx = c.keys.Dim(2)
}
c.offset += keys.Dim(2)
c.idx = c.keys.Dim(2)
return c.keys, c.values
}
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if not yet at max
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys != nil {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
c.idx = prev
}
// Trim to max_size to maintain sliding window
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
c.idx = c.maxSize
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.offset += L
c.idx += L
validLen := min(c.offset, c.maxSize)
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset < c.keys.Dim(2) {
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
return c.keys, c.values
}
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
c.idx -= n
return n
}
func (c *RotatingKVCache) Clone() Cache {
return &RotatingKVCache{
maxSize: c.maxSize,
idx: c.idx,
KVCache: c.KVCache.Clone().(*KVCache),
}
}
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

174
x/mlxrunner/client.go Normal file
View File

@@ -0,0 +1,174 @@
package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"math"
"net"
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
type Client struct {
Port int
*exec.Cmd
}
func (c *Client) JoinPath(path string) string {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
}).JoinPath(path).String()
}
func (c *Client) CheckError(w *http.Response) error {
if w.StatusCode >= 400 {
return errors.New(w.Status)
}
return nil
}
// Close implements llm.LlamaServer.
func (c *Client) Close() error {
return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(req); err != nil {
return err
}
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
if err != nil {
return err
}
defer w.Body.Close()
if err := c.CheckError(w); err != nil {
return err
}
scanner := bufio.NewScanner(w.Body)
for scanner.Scan() {
bts := scanner.Bytes()
var resp llm.CompletionResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
fn(resp)
}
return nil
}
func (c *Client) ContextLength() int {
return math.MaxInt
}
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
panic("unimplemented")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
panic("unimplemented")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
panic("unimplemented")
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.Port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
panic("unimplemented")
}
// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
if err != nil {
return nil, err
}
defer w.Body.Close()
return []ml.DeviceID{}, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
panic("unimplemented")
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
panic("unimplemented")
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
w, err := http.Get(c.JoinPath("/v1/status"))
if err != nil {
return err
}
defer w.Body.Close()
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
if err != nil {
return nil, err
}
defer w.Body.Close()
var tokens []int
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
return nil, err
}
return tokens, nil
}
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
panic("unimplemented")
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
panic("unimplemented")
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
panic("unimplemented")
}
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
panic("unimplemented")
}
var _ llm.LlamaServer = (*Client)(nil)

3
x/mlxrunner/mlx/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
_deps
build
dist

View File

@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.5)
project(mlx)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

23
x/mlxrunner/mlx/act.go Normal file
View File

@@ -0,0 +1,23 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import "math"
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
t.Add(
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(t.DType())
}
func SILU(t *Array) *Array {
return t.Multiply(t.Sigmoid()).AsType(t.DType())
}

273
x/mlxrunner/mlx/array.go Normal file
View File

@@ -0,0 +1,273 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"log/slog"
"reflect"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
}
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
return t
}
type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
case int:
tt.ctx = C.mlx_array_new_int(C.int(v))
case float32:
tt.ctx = C.mlx_array_new_float32(C.float(v))
case float64:
tt.ctx = C.mlx_array_new_float64(C.double(v))
case complex64:
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
default:
panic("unsupported type")
}
return tt
}
type arrayTypes interface {
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~int8 | ~int16 | ~int32 | ~int64 |
~float32 | ~float64 |
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
cShape := make([]C.int, len(shape))
for i := range shape {
cShape[i] = C.int(shape[i])
}
var dtype DType
switch reflect.TypeOf(s).Elem().Kind() {
case reflect.Bool:
dtype = DTypeBool
case reflect.Uint8:
dtype = DTypeUint8
case reflect.Uint16:
dtype = DTypeUint16
case reflect.Uint32:
dtype = DTypeUint32
case reflect.Uint64:
dtype = DTypeUint64
case reflect.Int8:
dtype = DTypeInt8
case reflect.Int16:
dtype = DTypeInt16
case reflect.Int32:
dtype = DTypeInt32
case reflect.Int64:
dtype = DTypeInt64
case reflect.Float32:
dtype = DTypeFloat32
case reflect.Float64:
dtype = DTypeFloat64
case reflect.Complex64:
dtype = DTypeComplex64
default:
panic("unsupported type")
}
bts := make([]byte, binary.Size(s))
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
panic(err)
}
tt := New("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
func (t *Array) Set(other *Array) {
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// misc. utilities
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
slog.Any("shape", t.Dims()),
slog.Int("num_bytes", t.NumBytes()),
)
}
return slog.GroupValue(attrs...)
}
// shape utilities
func (t Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
}
return dims
}
func (t Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
func (t Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats
}
func (t Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
return nil
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--
if t.desc.numRefs <= 0 {
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
}

View File

@@ -0,0 +1,45 @@
//go:build mlx
package mlx
import "testing"
func TestFromValue(t *testing.T) {
for got, want := range map[*Array]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}
func TestFromValues(t *testing.T) {
for got, want := range map[*Array]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}

96
x/mlxrunner/mlx/dtype.go Normal file
View File

@@ -0,0 +1,96 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
type DType int
func (t DType) String() string {
switch t {
case DTypeBool:
return "BOOL"
case DTypeUint8:
return "U8"
case DTypeUint16:
return "U16"
case DTypeUint32:
return "U32"
case DTypeUint64:
return "U64"
case DTypeInt8:
return "I8"
case DTypeInt16:
return "I16"
case DTypeInt32:
return "I32"
case DTypeInt64:
return "I64"
case DTypeFloat16:
return "F16"
case DTypeFloat32:
return "F32"
case DTypeFloat64:
return "F64"
case DTypeBFloat16:
return "BF16"
case DTypeComplex64:
return "C64"
default:
return "Unknown"
}
}
func (t *DType) UnmarshalJSON(b []byte) error {
switch string(b) {
case `"BOOL"`:
*t = DTypeBool
case `"U8"`:
*t = DTypeUint8
case `"U16"`:
*t = DTypeUint16
case `"U32"`:
*t = DTypeUint32
case `"U64"`:
*t = DTypeUint64
case `"I8"`:
*t = DTypeInt8
case `"I16"`:
*t = DTypeInt16
case `"I32"`:
*t = DTypeInt32
case `"I64"`:
*t = DTypeInt64
case `"F16"`:
*t = DTypeFloat16
case `"F64"`:
*t = DTypeFloat64
case `"F32"`:
*t = DTypeFloat32
case `"BF16"`:
*t = DTypeBFloat16
case `"C64"`:
*t = DTypeComplex64
default:
return nil
}
return nil
}
const (
DTypeBool DType = C.MLX_BOOL
DTypeUint8 DType = C.MLX_UINT8
DTypeUint16 DType = C.MLX_UINT16
DTypeUint32 DType = C.MLX_UINT32
DTypeUint64 DType = C.MLX_UINT64
DTypeInt8 DType = C.MLX_INT8
DTypeInt16 DType = C.MLX_INT16
DTypeInt32 DType = C.MLX_INT32
DTypeInt64 DType = C.MLX_INT64
DTypeFloat16 DType = C.MLX_FLOAT16
DTypeFloat32 DType = C.MLX_FLOAT32
DTypeFloat64 DType = C.MLX_FLOAT64
DTypeBFloat16 DType = C.MLX_BFLOAT16
DTypeComplex64 DType = C.MLX_COMPLEX64
)

34
x/mlxrunner/mlx/dynamic.c Normal file
View File

@@ -0,0 +1,34 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
CHECK(handle->ctx != NULL);
return 0;
}
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
return mlx_dynamic_open(handle, path);
}
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
if (handle->ctx) {
DLCLOSE(handle->ctx);
handle->ctx = NULL;
}
}

View File

@@ -0,0 +1,65 @@
//go:build mlx
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"unsafe"
)
func init() {
switch runtime.GOOS {
case "darwin":
case "windows":
default:
return
}
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
return
}
for _, path := range filepath.SplitList(paths) {
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
if err != nil {
panic(err)
}
for _, match := range matches {
path := filepath.Join(paths, match)
slog.Info("Loading MLX dynamic library", "path", path)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
slog.Info("Loaded MLX dynamic library", "path", path)
return
}
}
panic("Failed to load any MLX dynamic library")
}

41
x/mlxrunner/mlx/dynamic.h Normal file
View File

@@ -0,0 +1,41 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
#include <stdint.h>
// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
// platforms where arm_fp16.h and arm_bf16.h are not available. These are
// only used as function pointer signature placeholders since MLX requires
// Apple Silicon at runtime.
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
typedef uint16_t float16_t;
#endif
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
typedef uint16_t bfloat16_t;
#endif
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

74
x/mlxrunner/mlx/fast.go Normal file
View File

@@ -0,0 +1,74 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
if mask == nil {
mask = New("")
}
sinks := New("")
mode := "causal"
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RoPE struct {
Dims int
Traditional bool
Base float32 `json:"rope_theta"`
Scale float32
}
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
C.mlx_fast_rope(
&out.ctx,
t.ctx,
C.int(r.Dims),
C._Bool(r.Traditional),
C.mlx_optional_float{
value: C.float(r.Base),
has_value: C._Bool(func() bool { return r.Base != 0 }()),
},
C.float(r.Scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}

2724
x/mlxrunner/mlx/generated.c Normal file
View File

File diff suppressed because it is too large Load Diff

7135
x/mlxrunner/mlx/generated.h Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
CHECK_LOAD(handle, {{ .Name }});
{{- end }}
return 0;
}

View File

@@ -0,0 +1,22 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -0,0 +1,135 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
type Function struct {
Type,
Name,
Parameters,
Args string
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
var fn Function
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
if params := node.ChildByFieldName("parameters"); params != nil {
fn.Parameters = params.Utf8Text(source)
fn.Args = ParseParameters(params, tc, source)
}
var types []string
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
if node.Parent().Kind() == "pointer_declarator" {
types = append(types, "*")
}
node = node.Parent()
}
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
types = append(types, sibling.Utf8Text(source))
}
slices.Reverse(types)
fn.Type = strings.Join(types, " ")
return fn
}
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
var s []string
for _, child := range node.Children(tc) {
if child.IsNamed() {
child := child.ChildByFieldName("declarator")
for child != nil && child.Kind() != "identifier" {
if child.Kind() == "parenthesized_declarator" {
child = child.Child(1)
} else {
child = child.ChildByFieldName("declarator")
}
}
if child != nil {
s = append(s, child.Utf8Text(source))
}
}
}
return strings.Join(s, ", ")
}
func main() {
var output string
flag.StringVar(&output, "output", ".", "Output directory for generated files")
flag.Parse()
parser := tree_sitter.NewParser()
defer parser.Close()
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
parser.SetLanguage(language)
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
defer query.Close()
qc := tree_sitter.NewQueryCursor()
defer qc.Close()
var funs []Function
for _, arg := range flag.Args() {
bts, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
continue
}
tree := parser.Parse(bts, nil)
defer tree.Close()
tc := tree.Walk()
defer tc.Close()
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
funs = append(funs, ParseFunction(&capture.Node, tc, bts))
}
}
}
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
return
}
for _, tmpl := range tmpl.Templates() {
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
fmt.Println("Generating", name)
f, err := os.Create(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
continue
}
defer f.Close()
if err := tmpl.Execute(f, map[string]any{
"Functions": funs,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
}
}
}

45
x/mlxrunner/mlx/io.go Normal file
View File

@@ -0,0 +1,45 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"iter"
"unsafe"
)
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
string2string := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(string2string)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cpu := C.mlx_default_cpu_stream_new()
defer C.mlx_stream_free(cpu)
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
break
}
}
}
}

87
x/mlxrunner/mlx/memory.go Normal file
View File

@@ -0,0 +1,87 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
func (b Byte) String() string {
return strconv.FormatInt(int64(b), 10) + " B"
}
func (b KibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}
func (b MebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}
func (b GibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}
func (b TebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}
func PrettyBytes(n int) fmt.Stringer {
switch {
case n < 1<<10:
return Byte(n)
case n < 1<<(2*10):
return KibiByte(n)
case n < 1<<(3*10):
return MebiByte(n)
case n < 1<<(4*10):
return GibiByte(n)
default:
return TebiByte(n)
}
}
func ActiveMemory() int {
var active C.size_t
C.mlx_get_active_memory(&active)
return int(active)
}
func CacheMemory() int {
var cache C.size_t
C.mlx_get_cache_memory(&cache)
return int(cache)
}
func PeakMemory() int {
var peak C.size_t
C.mlx_get_peak_memory(&peak)
return int(peak)
}
type Memory struct{}
func (Memory) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("active", PrettyBytes(ActiveMemory())),
slog.Any("cache", PrettyBytes(CacheMemory())),
slog.Any("peak", PrettyBytes(PeakMemory())),
)
}
type (
Byte int
KibiByte int
MebiByte int
GibiByte int
TebiByte int
)
func ClearCache() {
C.mlx_clear_cache()
}

40
x/mlxrunner/mlx/mlx.go Normal file
View File

@@ -0,0 +1,40 @@
//go:build mlx
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"
func doEval(outputs []*Array, async bool) {
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
if async {
C.mlx_async_eval(vector)
} else {
C.mlx_eval(vector)
}
}
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Array) {
doEval(outputs, false)
}

38
x/mlxrunner/mlx/nn.go Normal file
View File

@@ -0,0 +1,38 @@
//go:build mlx
package mlx
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
}
type Embedding struct {
Weight Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() Linear {
return Linear{
Weight: e.Weight,
}
}

256
x/mlxrunner/mlx/ops.go Normal file
View File

@@ -0,0 +1,256 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS", t)
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
unsafe.SliceData(cStrides), C.size_t(len(strides)),
C.size_t(offset),
DefaultStream().ctx,
)
return out
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
s := append([]*Array{t}, others...)
for _, other := range s {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", s...)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE", t, other)
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN", t)
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE", t, other)
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if lhs == nil {
lhs = New("")
}
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM", t, other, lhs, rhs)
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS", append(others, t)...)
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS", t)
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", t, indices)
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])
}
t := New("ZEROS")
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
return t
}

View File

@@ -0,0 +1,427 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"reflect"
"unsafe"
)
// Quantization operations
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W")
C.mlx_vector_array_get(&w0.ctx, res, 0)
w1 := New("QUANTIZE_S")
C.mlx_vector_array_get(&w1.ctx, res, 1)
if vecSize >= 3 {
w2 := New("QUANTIZE_B")
C.mlx_vector_array_get(&w2.ctx, res, 2)
return w0, w1, w2
}
return w0, w1, nil
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
inputs := []*Array{w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("DEQUANTIZE", inputs...)
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
return out
}
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("QUANTIZED_MATMUL", inputs...)
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
inputs = append(inputs, lhsIndices)
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
inputs = append(inputs, rhsIndices)
}
out := New("GATHER_QMM", inputs...)
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
// Missing tensor ops
func Tile(a *Array, reps []int32) *Array {
cReps := make([]C.int, len(reps))
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE", a)
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
func Tri(n, m int32, k int) *Array {
out := New("TRI")
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
return out
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
vectorData := make([]C.mlx_array, len(arrays))
for i := range arrays {
vectorData[i] = arrays[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK", arrays...)
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func Neg(a *Array) *Array {
return a.Negative()
}
func Sum(a *Array, axis int, keepDims bool) *Array {
return a.SumAxis(axis, keepDims)
}
func Argsort(a *Array, axis int) *Array {
return a.ArgsortAxis(axis)
}
func Take(a *Array, indices *Array, axis int) *Array {
return a.TakeAxis(indices, axis)
}
func RSqrt(a *Array) *Array {
out := New("RSQRT", a)
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS", a)
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func Argpartition(a *Array, kth int, axis int) *Array {
return a.ArgpartitionAxis(kth, axis)
}
func TakeAlongAxis(a, indices *Array, axis int) *Array {
return a.TakeAlongAxis(indices, axis)
}
// Function-style wrappers matching imagegen API
func Add(a, b *Array) *Array {
return a.Add(b)
}
func Sub(a, b *Array) *Array {
return a.Subtract(b)
}
func Mul(a, b *Array) *Array {
return a.Multiply(b)
}
func Div(a, b *Array) *Array {
return a.Divide(b)
}
func Matmul(a, b *Array) *Array {
return a.Matmul(b)
}
func Reshape(a *Array, shape ...int32) *Array {
axes := make([]int, len(shape))
for i, s := range shape {
axes[i] = int(s)
}
return a.Reshape(axes...)
}
func Transpose(a *Array, axes ...int) *Array {
return a.Transpose(axes...)
}
func ExpandDims(a *Array, axis int) *Array {
return a.ExpandDims(axis)
}
func Squeeze(a *Array, axis int) *Array {
return a.Squeeze(axis)
}
func Flatten(a *Array) *Array {
return a.Flatten(0, -1)
}
func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 {
return nil
}
return arrays[0].Concatenate(axis, arrays[1:]...)
}
func SliceStartStop(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE", a)
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
if lhsIndices == nil {
lhsIndices = New("")
}
if rhsIndices == nil {
rhsIndices = New("")
}
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
func SiLU(a *Array) *Array {
sig := a.Sigmoid()
return a.Multiply(sig)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", x, freqs)
C.mlx_fast_rope(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
},
C.float(scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}
func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("")
sinks := New("")
mode := ""
if causalMask {
mode = "causal"
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", q, k, v, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
return c.Addmm(a, b, alpha, beta)
}
// Scalar helpers
func AddScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Add(scalar)
}
func MulScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Multiply(scalar)
}
func DivScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Divide(scalar)
}
func FloorDivideScalar(a *Array, s int32) *Array {
scalar := FromValue(int(s))
return a.FloorDivide(scalar)
}
// Array constructors
func NewArrayInt32(data []int32, shape []int32) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
out := New("NEW_ARRAY_INT32")
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
return out
}
func NewScalarArray(value float32) *Array {
out := New("SCALAR")
out.ctx = C.mlx_array_new_float32(C.float(value))
return out
}
func ZerosF32(shape []int32) *Array {
return Zeros(DTypeFloat32, func() []int {
ints := make([]int, len(shape))
for i, s := range shape {
ints[i] = int(s)
}
return ints
}()...)
}
// Utility
func Collect(v any) []*Array {
var arrays []*Array
seen := make(map[uintptr]bool)
collect(reflect.ValueOf(v), &arrays, seen)
return arrays
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return
}
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
ptr := v.Pointer()
if seen[ptr] {
return
}
seen[ptr] = true
if arr, ok := v.Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
collect(v.Elem(), arrays, seen)
return
}
switch v.Kind() {
case reflect.Struct:
// Check if this struct IS an Array (not a pointer to one)
if arr, ok := v.Addr().Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanInterface() {
collect(field, arrays, seen)
}
}
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
collect(v.Index(i), arrays, seen)
}
case reflect.Map:
for _, key := range v.MapKeys() {
collect(v.MapIndex(key), arrays, seen)
}
case reflect.Interface:
if !v.IsNil() {
collect(v.Elem(), arrays, seen)
}
}
}
func EnableCompile() {
C.mlx_enable_compile()
}
func DisableCompile() {
C.mlx_disable_compile()
}

13
x/mlxrunner/mlx/random.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

86
x/mlxrunner/mlx/slice.go Normal file
View File

@@ -0,0 +1,86 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"cmp"
"unsafe"
)
type slice struct {
args []int
}
func Slice(args ...int) slice {
return slice{args: args}
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
}
args := [3][]C.int{
make([]C.int, len(slices)),
make([]C.int, len(slices)),
make([]C.int, len(slices)),
}
for i, s := range slices {
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dims[i])
args[2][i] = C.int(1)
case 1:
// slice[i]
args[0][i] = C.int(s.args[0])
args[1][i] = C.int(s.args[0] + 1)
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")
}
}
return args[0], args[1], args[2]
}
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}

45
x/mlxrunner/mlx/stream.go Normal file
View File

@@ -0,0 +1,45 @@
//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"log/slog"
"sync"
)
type Device struct {
ctx C.mlx_device
}
func (d Device) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_device_tostring(&str, d.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultDevice = sync.OnceValue(func() Device {
d := C.mlx_device_new()
C.mlx_get_default_device(&d)
return Device{d}
})
type Stream struct {
ctx C.mlx_stream
}
func (s Stream) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_stream_tostring(&str, s.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultStream = sync.OnceValue(func() Stream {
s := C.mlx_stream_new()
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
return Stream{s}
})

123
x/mlxrunner/pipeline.go Normal file
View File

@@ -0,0 +1,123 @@
//go:build mlx
package mlxrunner
import (
"bytes"
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches))
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
var b bytes.Buffer
now := time.Now()
final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
outputs = append(outputs, output)
if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
break
}
request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}
mlx.Free(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
return nil
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token := r.Tokenizer.Decode([]int32{sample})
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
}

139
x/mlxrunner/runner.go Normal file
View File

@@ -0,0 +1,139 @@
//go:build mlx
package mlxrunner
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/models/glm4_moe_lite"
)
// TextModel is the interface that model implementations must satisfy.
type TextModel interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
}
type Request struct {
TextCompletionsRequest
Responses chan Response
Pipeline func(Request) error
sample.Sampler
caches []cache.Cache
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model TextModel
Tokenizer *tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
}
func (r *Runner) Load(modelName string) error {
modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return err
}
// Read config to detect architecture
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return fmt.Errorf("failed to parse config.json: %w", err)
}
if len(archConfig.Architectures) == 0 {
return fmt.Errorf("no architectures found in config.json")
}
slog.Info("Model architecture", "arch", archConfig.Architectures[0])
switch archConfig.Architectures[0] {
case "Glm4MoeLiteForCausalLM", "GLM4MoeLite":
model, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err)
}
r.Model = model
r.Tokenizer = model.Tokenizer()
default:
return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0])
}
return nil
}
func (r *Runner) Run(host, port string, mux http.Handler) error {
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
break
}
close(request.Responses)
}
}
})
g.Go(func() error {
slog.Info("Starting HTTP server", "host", host, "port", port)
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
})
return g.Wait()
}

View File

@@ -0,0 +1,77 @@
//go:build mlx
package sample
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Sampler interface {
Sample(*mlx.Array) *mlx.Array
}
func New(temp, top_p, min_p float32, top_k int) Sampler {
if temp == 0 {
return greedy{}
}
var samplers []Sampler
if top_p > 0 && top_p < 1 {
samplers = append(samplers, TopP(top_p))
}
if min_p != 0 {
samplers = append(samplers, MinP(min_p))
}
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers)
}
type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c {
logits = sampler.Sample(logits)
}
return logits
}
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}
type TopP float32
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type MinP float32
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type TopK int
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}

176
x/mlxrunner/server.go Normal file
View File

@@ -0,0 +1,176 @@
//go:build mlx
package mlxrunner
import (
"bytes"
"cmp"
"encoding/json"
"flag"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
var (
modelName string
port int
)
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
flagSet.StringVar(&modelName, "model", "", "Model name")
flagSet.IntVar(&port, "port", 0, "Port to listen on")
_ = flagSet.Bool("verbose", false, "Enable debug logging")
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
}
if err := runner.Load(modelName); err != nil {
return err
}
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{
"status": 0,
"progress": 100,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "POST":
fallthrough
case "GET":
if err := json.NewEncoder(w).Encode(map[string]any{
"Success": true,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
case "DELETE":
// TODO: cleanup model and cache
}
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
)
runner.Requests <- request
w.Header().Set("Content-Type", "application/jsonl")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
for response := range request.Responses {
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
})
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
var b bytes.Buffer
if _, err := io.Copy(&b, r.Body); err != nil {
slog.Error("Failed to read request body", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
tokens := runner.Tokenizer.Encode(b.String(), true)
if err := json.NewEncoder(w).Encode(tokens); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
for source, target := range map[string]string{
"GET /health": "/v1/status",
"POST /load": "/v1/models",
"POST /completion": "/v1/completions",
} {
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
}
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
t := time.Now()
mux.ServeHTTP(recorder, r)
var level slog.Level
switch {
case recorder.code >= 500:
level = slog.LevelError
case recorder.code >= 400:
level = slog.LevelWarn
case recorder.code >= 300:
return
}
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
}))
}
type statusRecorder struct {
http.ResponseWriter
code int
}
func (w *statusRecorder) WriteHeader(code int) {
w.code = code
w.ResponseWriter.WriteHeader(code)
}
func (w *statusRecorder) Status() string {
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}
func (w *statusRecorder) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

View File

@@ -0,0 +1,10 @@
//go:build !mlx
package mlxrunner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
return errors.New("MLX runner not available: build with mlx tag")
}

View File

@@ -0,0 +1,860 @@
//go:build mlx
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"os"
"strings"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
Factor float32 `json:"factor"`
MscaleAllDim float32 `json:"mscale_all_dim"`
}
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// RoPE scaling
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}
// MLAAttention implements Multi-head Latent Attention with absorption.
type MLAAttention struct {
QAProj nn.LinearLayer
QALayerNorm *nn.RMSNorm
QBProj nn.LinearLayer
KVAProjWithMQA nn.LinearLayer
KVALayerNorm *nn.RMSNorm
EmbedQ *nn.MultiLinear
UnembedOut *nn.MultiLinear
OProj nn.LinearLayer
}
// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
compressedKV := a.KVAProjWithMQA.Forward(x)
kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kvLatent = mlx.ExpandDims(kvLatent, 1)
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
qLatent := a.EmbedQ.Forward(qNope)
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
cachedL := L
if c != nil {
placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
keys, _ = c.Update(keys, placeholderValues)
cachedL = int32(keys.Dim(2))
}
values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
out = a.UnembedOut.Forward(out)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer
EScoreCorrectionBias *mlx.Array
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
gates := g.Gate.Forward(x)
scores := mlx.Sigmoid(gates)
origScores := scores
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
}
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
dims := inds.Dims()
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
scores = mlx.TakeAlongAxis(origScores, inds, -1)
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
type SwitchMLP struct {
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
GateWeightQ, GateScales, GateBiases *mlx.Array
UpWeightQ, UpScales, UpBiases *mlx.Array
DownWeightQ, DownScales, DownBiases *mlx.Array
GateBits int
UpBits int
DownBits int
GateGroupSize int
UpGroupSize int
DownGroupSize int
UseQuantized bool
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
dims := x.Dims()
B, L := int32(dims[0]), int32(dims[1])
topK := cfg.NumExpertsPerTok
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
idxFlat := mlx.Reshape(indices, B*L, topK)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
var gate, up, hidden, down *mlx.Array
if s.UseQuantized {
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
dims := x.Dims()
B, L := int32(dims[0]), int32(dims[1])
inds, scores := m.Gate.Forward(x, cfg)
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm
PostAttentionLayerNorm *nn.RMSNorm
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm
PostAttentionLayerNorm *nn.RMSNorm
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding
Layers []Block
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
}
// computeScale computes the attention scale.
func computeScale(cfg *Config) float32 {
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
scale *= s * s
}
return scale
}
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// quantizationParams returns groupSize, bits, mode for a quantization type string.
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
}
metaRaw, ok := header["__metadata__"]
if !ok {
return nil, nil
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err
}
return meta, nil
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array
Scales *mlx.Array
Biases *mlx.Array
Bits int
GroupSize int
}
// loadExpertWeight loads an expert weight from the tensor map.
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
}
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
}
return &ExpertWeight{Weight: w}
}
// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
Weight *mlx.Array
Scales *mlx.Array
Biases *mlx.Array
Bits int
GroupSize int
}
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
tensors map[string]*mlx.Array,
prefix string,
projName string,
numExperts int32,
useQuantized bool,
cfg *Config,
) *StackedExpertWeights {
var w, s, b []*mlx.Array
var bits, groupSize int
for e := int32(0); e < numExperts; e++ {
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
ew := loadExpertWeight(tensors, path, useQuantized, cfg)
if ew == nil {
continue
}
w = append(w, ew.Weight)
if ew.Scales != nil {
s = append(s, ew.Scales)
}
if ew.Biases != nil {
b = append(b, ew.Biases)
}
if e == 0 {
bits = ew.Bits
groupSize = ew.GroupSize
}
}
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
if len(w) > 0 {
result.Weight = mlx.Stack(w, 0)
if len(s) > 0 {
result.Scales = mlx.Stack(s, 0)
}
if len(b) > 0 {
result.Biases = mlx.Stack(b, 0)
}
}
return result
}
// sanitizeExpertWeights stacks individual expert weights into tensors.
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
return gate, up, down
}
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
path := prefix + ".self_attn.kv_b_proj"
w := tensors[path+".weight"]
if w == nil {
return nil, nil
}
// Check if quantized and dequantize
if scales := tensors[path+".weight_scale"]; scales != nil {
qbiases := tensors[path+".weight_qbias"]
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
}
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
embedQ := mlx.Transpose(wk, 0, 2, 1)
unembedOut := wv
return embedQ, unembedOut
}
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: cfg.QuantGroupSize,
Bits: cfg.QuantBits,
Mode: cfg.QuantMode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
// Load all tensors from manifest blobs into a flat map
allTensors := make(map[string]*mlx.Array)
seen := make(map[string]bool) // dedupe by digest
var quantType string
var quantGroupSize int
for _, layer := range modelManifest.GetTensorLayers("") {
if seen[layer.Digest] {
continue
}
seen[layer.Digest] = true
blobPath := modelManifest.BlobPath(layer.Digest)
// Read quantization metadata from first blob
if quantType == "" {
if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil {
if qt := meta["quant_type"]; qt != "" {
quantType = strings.ToUpper(qt)
}
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &quantGroupSize)
}
}
}
for name, arr := range mlx.Load(blobPath) {
// Map safetensors key naming to our naming convention
// Combined blobs use ".scale" and ".bias" suffixes
if strings.HasSuffix(name, ".scale") {
baseName := strings.TrimSuffix(name, ".scale")
allTensors[baseName+"_scale"] = arr
} else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
// Check if this is a quantization bias or a regular bias
// by checking if there's a corresponding weight
baseName := strings.TrimSuffix(name, ".bias")
if _, hasScale := allTensors[baseName+"_scale"]; hasScale {
allTensors[baseName+"_qbias"] = arr
} else {
allTensors[name] = arr
}
} else {
allTensors[name] = arr
}
}
}
// Set up quantization parameters
useQuantized := false
if quantType != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(quantType)
if quantGroupSize > 0 {
cfg.QuantGroupSize = quantGroupSize
} else {
cfg.QuantGroupSize, _, _ = quantizationParams(quantType)
}
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
}
// Load tokenizer
tokData, err := modelManifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding
if w := allTensors["model.embed_tokens.weight"]; w != nil {
m.EmbedTokens = nn.NewEmbedding(w)
}
// Load final norm
if w := allTensors["model.norm.weight"]; w != nil {
m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
// Load LM head
m.LMHead = makeLinear(allTensors, "lm_head", &cfg)
// Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg)
if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg)
attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg)
if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg)
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
inputLN := allTensors[prefix+".input_layernorm.weight"]
postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"]
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if inputLN != nil {
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
}
if postAttnLN != nil {
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
}
block.MLP = &DenseMLP{
GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg),
UpProj: makeLinear(allTensors, prefix+".mlp.up_proj", &cfg),
DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg),
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if inputLN != nil {
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
}
if postAttnLN != nil {
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
}
// Stack expert weights
gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
switchMLP.GateWeightQ = gate.Weight
switchMLP.GateScales = gate.Scales
switchMLP.GateBiases = gate.Biases
switchMLP.GateBits = gate.Bits
switchMLP.GateGroupSize = gate.GroupSize
switchMLP.UpWeightQ = up.Weight
switchMLP.UpScales = up.Scales
switchMLP.UpBiases = up.Biases
switchMLP.UpBits = up.Bits
switchMLP.UpGroupSize = up.GroupSize
switchMLP.DownWeightQ = down.Weight
switchMLP.DownScales = down.Scales
switchMLP.DownBiases = down.Biases
switchMLP.DownBits = down.Bits
switchMLP.DownGroupSize = down.GroupSize
} else {
switchMLP.GateWeight = gate.Weight
switchMLP.UpWeight = up.Weight
switchMLP.DownWeight = down.Weight
}
moeGate := &MoEGate{}
moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg)
if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias
}
block.MoE = &MoE{
Gate: moeGate,
SwitchMLP: switchMLP,
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{
GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg),
UpProj: makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg),
DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg),
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
return m, nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return h
}
// Unembed applies the LM head to get logits.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
if think {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
return &Renderer{}
}
// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
return &Parser{}
}

View File

@@ -0,0 +1,479 @@
//go:build mlx
package glm4_moe_lite
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type parserState int
const (
parserState_LookingForThinkingOpen parserState = iota
parserState_ThinkingStartedEatingWhitespace
parserState_CollectingThinking
parserState_ThinkingDoneEatingWhitespace
parserState_CollectingContent
parserState_ToolStartedEatingWhitespace
parserState_CollectingToolContent
)
const (
thinkingOpenTag = "<think>"
thinkingCloseTag = "</think>"
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
state parserState
buffer strings.Builder
tools []api.Tool
}
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
return true
}
// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
return true
}
// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = parserState_CollectingThinking
}
return tools
}
type parserEvent interface {
isParserEvent()
}
type eventContent struct {
content string
}
func (eventContent) isParserEvent() {}
type eventRawToolCall struct {
raw string
}
func (eventRawToolCall) isParserEvent() {}
type eventThinkingContent struct {
content string
}
func (eventThinkingContent) isParserEvent() {}
// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case eventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case eventThinkingContent:
thinkingSb.WriteString(event.content)
case eventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Parser) parseEvents() []parserEvent {
var all []parserEvent
keepLooping := true
for keepLooping {
var events []parserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *Parser) eat() ([]parserEvent, bool) {
var events []parserEvent
switch p.state {
case parserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, thinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = parserState_ThinkingStartedEatingWhitespace
} else {
p.state = parserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = parserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case parserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
case parserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, eventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = parserState_ThinkingDoneEatingWhitespace
} else {
p.state = parserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
}
case parserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
case parserState_CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
before, after := p.splitAtTag(toolOpenTag, true)
if len(before) > 0 {
events = append(events, eventContent{content: before})
}
if after == "" {
p.state = parserState_ToolStartedEatingWhitespace
} else {
p.state = parserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
}
case parserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
case parserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, toolCloseTag) {
toolContent, _ := p.splitAtTag(toolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm4 tool call closing tag found but no content before it")
}
events = append(events, eventRawToolCall{raw: toolContent})
p.state = parserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
for i := 1; i <= len(tag) && i <= len(s); i++ {
if strings.HasSuffix(s, tag[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
return len(s) - len(trimmed)
}
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
escaped := escapeContent(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed ToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
value = strings.TrimSpace(value)
// If no type specified, return as string
if len(paramType) == 0 {
return value
}
// Try to parse based on specified types
for _, t := range paramType {
switch t {
case "boolean":
if value == "true" {
return true
}
if value == "false" {
return false
}
case "integer":
var i int64
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
return i
}
case "number":
var f float64
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
return f
}
case "array", "object":
// Try to parse as JSON
var result any
if err := json.Unmarshal([]byte(value), &result); err == nil {
return result
}
}
}
// Default to string
return value
}

View File

@@ -0,0 +1,192 @@
//go:build mlx
package glm4_moe_lite
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestParserThinking(t *testing.T) {
tests := []struct {
name string
input string
thinkEnabled bool
wantContent string
wantThinking string
wantToolCalls int
}{
{
name: "thinking enabled - simple thinking then content",
input: "Let me think about this...</think>Here is my answer.",
thinkEnabled: true,
wantThinking: "Let me think about this...",
wantContent: "Here is my answer.",
},
{
name: "thinking enabled - only thinking",
input: "I need to consider multiple factors...",
thinkEnabled: true,
wantThinking: "I need to consider multiple factors...",
wantContent: "",
},
{
name: "thinking disabled - direct content",
input: "Here is my direct answer.",
thinkEnabled: false,
wantThinking: "",
wantContent: "Here is my direct answer.",
},
{
name: "thinking with tool call",
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
thinkEnabled: true,
wantThinking: "Let me search for that...",
wantContent: "I'll use a tool.",
wantToolCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Parser{}
var thinkValue *api.ThinkValue
if tt.thinkEnabled {
thinkValue = &api.ThinkValue{Value: true}
} else {
thinkValue = &api.ThinkValue{Value: false}
}
// Define tools for tool call tests
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "search",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
p.Init(tools, nil, thinkValue)
content, thinking, calls, err := p.Add(tt.input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if thinking != tt.wantThinking {
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
}
if content != tt.wantContent {
t.Errorf("content = %q, want %q", content, tt.wantContent)
}
if len(calls) != tt.wantToolCalls {
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
}
})
}
}
func TestParserToolCall(t *testing.T) {
p := &Parser{}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
// Initialize with thinking disabled
tv := &api.ThinkValue{Value: false}
p.Init(tools, nil, tv)
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
_, _, calls, err := p.Add(input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
call := calls[0]
if call.Function.Name != "get_weather" {
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
}
location, ok := call.Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Errorf("location = %v, want %q", location, "San Francisco")
}
unit, ok := call.Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Errorf("unit = %v, want %q", unit, "celsius")
}
}
func TestOverlap(t *testing.T) {
tests := []struct {
s string
tag string
want int
}{
{"hello<", "</think>", 1},
{"hello</", "</think>", 2},
{"hello</t", "</think>", 3},
{"hello</th", "</think>", 4},
{"hello</thi", "</think>", 5},
{"hello</thin", "</think>", 6},
{"hello</think", "</think>", 7},
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
{"hello", "</think>", 0},
{"", "</think>", 0},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
got := overlap(tt.s, tt.tag)
if got != tt.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
}
})
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
tests := []struct {
s string
want int
}{
{"hello ", 3},
{"hello\n\t ", 3},
{"hello", 0},
{"", 0},
{" ", 3},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
got := trailingWhitespaceLen(tt.s)
if got != tt.want {
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,205 @@
//go:build mlx
package glm4_moe_lite
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestRendererSimple(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
// Thinking enabled (default)
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererThinkingDisabled(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
tv := &api.ThinkValue{Value: false}
result, err := r.Render(messages, nil, tv)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererMultiTurn(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
{Role: "user", Content: "And 3+3?"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check key parts
if !strings.Contains(result, "[gMASK]<sop>") {
t.Error("missing [gMASK]<sop> prefix")
}
if !strings.Contains(result, "<|user|>What is 2+2?") {
t.Error("missing first user message")
}
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
t.Error("missing assistant message with thinking")
}
if !strings.Contains(result, "<|user|>And 3+3?") {
t.Error("missing second user message")
}
if !strings.HasSuffix(result, "<|assistant|><think>") {
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
}
}
func TestRendererWithSystem(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
t.Error("missing system message")
}
}
func TestRendererWithTools(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What's the weather?"},
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"location"},
},
},
},
}
result, err := r.Render(messages, tools, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check for tool system prompt
if !strings.Contains(result, "<|system|>") {
t.Error("missing system tag for tools")
}
if !strings.Contains(result, "# Tools") {
t.Error("missing tools header")
}
if !strings.Contains(result, "<tools>") {
t.Error("missing tools tag")
}
if !strings.Contains(result, "get_weather") {
t.Error("missing tool name")
}
if !strings.Contains(result, "</tools>") {
t.Error("missing closing tools tag")
}
}
func TestRendererWithToolCalls(t *testing.T) {
r := &Renderer{}
args := api.NewToolCallFunctionArguments()
args.Set("location", "San Francisco")
messages := []api.Message{
{Role: "user", Content: "What's the weather in SF?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<tool_call>get_weather") {
t.Error("missing tool call")
}
if !strings.Contains(result, "<arg_key>location</arg_key>") {
t.Error("missing arg_key")
}
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
t.Error("missing arg_value")
}
if !strings.Contains(result, "</tool_call>") {
t.Error("missing tool call closing tag")
}
if !strings.Contains(result, "<|observation|>") {
t.Error("missing observation tag")
}
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
t.Error("missing tool response")
}
}
func TestFormatToolJSON(t *testing.T) {
input := []byte(`{"name":"test","value":123}`)
result := formatToolJSON(input)
// Should add spaces after : and ,
if !strings.Contains(result, ": ") {
t.Error("should add space after colon")
}
if !strings.Contains(result, ", ") {
t.Error("should add space after comma")
}
}

188
x/models/nn/nn.go Normal file
View File

@@ -0,0 +1,188 @@
//go:build mlx
package nn
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array
Bias *mlx.Array
}
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
w := l.Weight.Transpose(1, 0)
if l.Bias != nil && l.Bias.Valid() {
return l.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (l *Linear) OutputDim() int32 {
return int32(l.Weight.Dim(0))
}
// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (nil for nvfp4)
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
Mode string
}
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
if ql.Bias != nil && ql.Bias.Valid() {
out = out.Add(ql.Bias)
}
return out
}
func (ql *QuantizedLinear) OutputDim() int32 {
return int32(ql.Weight.Dim(0))
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array
Eps float32
}
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNormFn(x, rn.Weight, eps)
}
// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array
}
func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return e.Weight.TakeAxis(indices, 0)
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array
Bias *mlx.Array
Eps float32
}
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
mean := mlx.Mean(x, -1, true)
centered := x.Subtract(mean)
variance := mlx.Mean(centered.Multiply(centered), -1, true)
normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps)))
out := normalized.Multiply(ln.Weight)
if ln.Bias != nil && ln.Bias.Valid() {
out = out.Add(ln.Bias)
}
return out
}
// MultiLinearLayer is an interface for per-head linear layers.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
type MultiLinear struct {
Weight *mlx.Array
}
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
wT := ml.Weight.Transpose(0, 2, 1)
return x.Matmul(wT)
}
// RepeatKV repeats K/V tensors for grouped query attention.
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
if repeatFactor == 1 {
return x
}
shape := x.Dims()
x = x.ExpandDims(2)
reps := []int32{1, 1, repeatFactor, 1, 1}
x = mlx.Tile(x, reps)
return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3]))
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
shape := scores.Dims()
seqLen := int32(shape[2])
mask := mlx.Tri(seqLen, seqLen, 0)
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}
// ApplyCausalMaskWithOffset applies causal mask for cached attention.
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
shape := scores.Dims()
queryLen := int32(shape[2])
keyLen := int32(shape[3])
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"sort"
"strings"
"github.com/ollama/ollama/api"
@@ -105,9 +106,9 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
bytesPerParam = 1
}
// Subtract safetensors header overhead (88 bytes per tensor file)
// Each tensor is stored as a minimal safetensors file
totalBytes := totalTensorBytes - tensorCount*88
// Subtract safetensors header overhead per tensor blob.
// Headers include __metadata__ with the tensor name, so overhead is ~150 bytes on average.
totalBytes := totalTensorBytes - tensorCount*150
paramCount := totalBytes / bytesPerParam
@@ -163,24 +164,103 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
// For quantized tensors, reads quant_type from blob __metadata__.
// For packed blobs (multiple tensors per blob), enumerates all tensors in the blob.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
// First pass: collect all tensor info and identify scale tensors
type tensorData struct {
info *safetensorsTensorInfo
digest string
}
tensorMap := make(map[string]*tensorData)
scaleMap := make(map[string]*tensorData) // base name -> scale tensor info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
// Read the safetensors header from the blob
// Read all tensor entries from the safetensors header
blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
continue
}
f, err := os.Open(blobPath)
if err != nil {
continue
}
allInfos, err := parseSafetensorsAllHeaders(f)
f.Close()
if err != nil {
continue
}
// Determine if this is a packed blob (multiple main tensors)
isPacked := len(allInfos) > 1
for _, info := range allInfos {
tensorName := layer.Name
if isPacked {
// For packed blobs, use the tensor name from the header
tensorName = info.Name
}
if info.QuantType != "" {
quantType := strings.ToUpper(info.QuantType)
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}
var packFactor int64
switch strings.ToLower(info.QuantType) {
case "int4", "nvfp4":
packFactor = 8
case "int8", "mxfp8":
packFactor = 4
}
if packFactor > 0 && len(shape) >= 2 {
shape[len(shape)-1] = uint64(info.Shape[len(info.Shape)-1] * packFactor)
}
tensors = append(tensors, api.Tensor{
Name: tensorName,
Type: quantType,
Shape: shape,
})
} else {
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: tensorName,
Type: info.Dtype,
Shape: shape,
})
}
}
}
sort.Slice(tensors, func(i, j int) bool {
return tensors[i].Name < tensors[j].Name
})
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// Reads quant_type from the first tensor blob's __metadata__.
// Falls back to torch_dtype from config.json if no quant metadata.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// Check first tensor blob for quant_type metadata
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
continue
@@ -189,131 +269,11 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
if err != nil {
continue
}
td := &tensorData{info: info, digest: layer.Digest}
if strings.HasSuffix(layer.Name, "_scale") {
baseName := strings.TrimSuffix(layer.Name, "_scale")
scaleMap[baseName] = td
} else if strings.HasSuffix(layer.Name, "_qbias") {
// Skip qbias tensors - they're included with the quantized weight
continue
} else {
tensorMap[layer.Name] = td
if info.QuantType != "" {
return strings.ToUpper(info.QuantType), nil
}
}
// Second pass: build tensor list with quantization info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
// Skip scale and qbias tensors
if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
continue
}
td := tensorMap[layer.Name]
if td == nil {
continue
}
// Check if this tensor has a corresponding scale tensor (quantized)
scaleTd := scaleMap[layer.Name]
if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
// Quantized tensor - detect bits from shapes
weightCols := td.info.Shape[len(td.info.Shape)-1]
scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]
// Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
// Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
// Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
var bits int
var quantType string
if weightCols*8/scaleCols == 32 {
bits = 4
quantType = "Q4"
} else if weightCols*4/scaleCols == 64 {
bits = 8
quantType = "Q8"
} else {
// Unknown quantization, show raw
quantType = td.info.Dtype
}
// Calculate unpacked shape
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
if bits > 0 {
packFactor := int64(32 / bits)
shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: quantType,
Shape: shape,
})
} else {
// Non-quantized tensor
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: td.info.Dtype,
Shape: shape,
})
}
}
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// First try to read quantization from model_index.json
var modelIndex struct {
Quantization string `json:"quantization"`
}
if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
return modelIndex.Quantization, nil
}
// Fallback: detect from tensor names
hasScales := false
hasQBias := false
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
hasScales = true
}
if strings.HasSuffix(layer.Name, "_qbias") {
hasQBias = true
}
}
}
if hasScales {
if hasQBias {
// Affine mode (has scale + qbias) - could be Q4 or Q8
// Default to Q4 as it's more common
return "Q4", nil
}
// No qbias = NVFP4
return "NVFP4", nil
// Only check the first tensor blob
break
}
// Not quantized - return torch_dtype from config.json
@@ -329,8 +289,11 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
Name string // tensor name from the header key
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
QuantType string // from __metadata__.quant_type (e.g., "int4", "int8", "nvfp4", "mxfp8")
GroupSize string // from __metadata__.group_size (e.g., "32", "64")
}
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
@@ -347,6 +310,7 @@ func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
// Parses __metadata__ for quant_type and group_size if present.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
@@ -371,7 +335,31 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
// Find the first (and should be only) tensor entry
// Parse metadata if present
var quantType, groupSize string
if metaRaw, ok := header["__metadata__"]; ok {
var meta map[string]string
if json.Unmarshal(metaRaw, &meta) == nil {
quantType = meta["quant_type"]
groupSize = meta["group_size"]
}
}
// Find the main tensor entry (not __metadata__, .scale, or .bias)
for name, raw := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
info.QuantType = quantType
info.GroupSize = groupSize
return &info, nil
}
// Fall back to first non-metadata tensor entry
for name, raw := range header {
if name == "__metadata__" {
continue
@@ -380,8 +368,134 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
info.QuantType = quantType
info.GroupSize = groupSize
return &info, nil
}
return nil, fmt.Errorf("no tensor found in header")
}
// parseSafetensorsAllHeaders parses all tensor entries from a safetensors header.
// Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias).
// For packed blobs this returns multiple entries; for single-tensor blobs, one entry.
// Each tensor's quant type is inferred from its shape and the presence of .scale/.bias entries
// when no global __metadata__ quant_type is present.
func parseSafetensorsAllHeaders(r io.Reader) ([]safetensorsTensorInfo, error) {
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
if headerSize > 100*1024*1024 { // 100MB limit for packed blob headers
return nil, fmt.Errorf("header size too large: %d", headerSize)
}
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
// Parse global metadata if present
var globalQuantType, globalGroupSize string
if metaRaw, ok := header["__metadata__"]; ok {
var meta map[string]string
if json.Unmarshal(metaRaw, &meta) == nil {
globalQuantType = meta["quant_type"]
globalGroupSize = meta["group_size"]
}
}
// Build a set of all keys for checking .scale/.bias presence
headerKeys := make(map[string]bool, len(header))
for k := range header {
headerKeys[k] = true
}
// Collect all main tensor entries (sorted for deterministic output)
var mainNames []string
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
mainNames = append(mainNames, name)
}
sort.Strings(mainNames)
var results []safetensorsTensorInfo
for _, name := range mainNames {
var info safetensorsTensorInfo
if err := json.Unmarshal(header[name], &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info for %s: %w", name, err)
}
info.Name = name
if globalQuantType != "" {
// Use global metadata
info.QuantType = globalQuantType
info.GroupSize = globalGroupSize
} else if headerKeys[name+".scale"] {
// No global metadata, but has .scale - infer quant type from shape
info.QuantType = inferQuantType(header, name)
}
results = append(results, info)
}
if len(results) == 0 {
return nil, fmt.Errorf("no tensor found in header")
}
return results, nil
}
// inferQuantType infers the quantization type for a tensor from its shape and scale shape.
// Returns "int4", "int8", etc. or "" if not quantized.
func inferQuantType(header map[string]json.RawMessage, name string) string {
// Parse the main tensor shape
var mainInfo struct {
Shape []int64 `json:"shape"`
}
if json.Unmarshal(header[name], &mainInfo) != nil || len(mainInfo.Shape) < 2 {
return ""
}
// Parse scale shape to determine group size
scaleRaw, ok := header[name+".scale"]
if !ok {
return ""
}
var scaleInfo struct {
Shape []int64 `json:"shape"`
}
if json.Unmarshal(scaleRaw, &scaleInfo) != nil || len(scaleInfo.Shape) < 2 {
return ""
}
// Calculate group size: main_cols * pack_factor / scale_cols
// Main dtype is U32, so we need to figure out the pack factor
// For int4: pack=8, group=32. scale_cols = original_cols / 32 = main_cols * 8 / 32 = main_cols / 4
// For int8: pack=4, group=64. scale_cols = original_cols / 64 = main_cols * 4 / 64 = main_cols / 16
mainCols := mainInfo.Shape[len(mainInfo.Shape)-1]
scaleCols := scaleInfo.Shape[len(scaleInfo.Shape)-1]
if scaleCols == 0 {
return ""
}
ratio := mainCols / scaleCols // main_packed_cols / scale_cols
// int4: ratio = (orig/8) / (orig/32) = 32/8 = 4
// int8: ratio = (orig/4) / (orig/64) = 64/4 = 16
switch ratio {
case 4:
return "int4"
case 16:
return "int8"
default:
return ""
}
}

View File

@@ -36,7 +36,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
totalTensorBytes: 8_600_000_150, // ~4.3B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -57,7 +57,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 32000,
TorchDtype: "float16",
},
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
totalTensorBytes: 14_000_000_150, // ~7B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
@@ -84,7 +84,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088,
totalTensorBytes: 8_600_000_150,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -101,7 +101,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
totalTensorBytes: 400_000_150, // 100M params * 4 bytes + 150 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
@@ -118,7 +118,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
totalTensorBytes: 2_001_500, // 1M params * 2 bytes + 10 tensors * 150 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
@@ -230,42 +230,42 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
{
name: "bfloat16",
dtype: "bfloat16",
totalBytes: 2_000_088, // 1M * 2 + 88
totalBytes: 2_000_150, // 1M * 2 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
totalBytes: 4_000_088, // 1M * 4 + 88
totalBytes: 4_000_150, // 1M * 4 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
totalBytes: 1_000_088, // 1M * 1 + 88
totalBytes: 1_000_150, // 1M * 1 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
@@ -288,11 +288,13 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantDtype string
wantShape []int64
wantErr bool
name string
header map[string]any
wantDtype string
wantShape []int64
wantQuantType string
wantGroupSize string
wantErr bool
}{
{
name: "simple tensor",
@@ -307,7 +309,70 @@ func TestParseSafetensorsHeader(t *testing.T) {
wantShape: []int64{2560, 262144},
},
{
name: "with metadata",
name: "tensor keyed by name",
header: map[string]any{
"model.layers.0.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 2560},
"data_offsets": []int64{0, 13107200},
},
},
wantDtype: "BF16",
wantShape: []int64{2560, 2560},
},
{
name: "with int4 quant metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 320},
"data_offsets": []int64{0, 3276800},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3276800, 3686400},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3686400, 4096000},
},
},
wantDtype: "U32",
wantShape: []int64{2560, 320},
wantQuantType: "int4",
wantGroupSize: "32",
},
{
name: "int8 quant metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int8",
"group_size": "64",
},
"model.layers.0.mlp.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 640},
"data_offsets": []int64{0, 6553600},
},
"model.layers.0.mlp.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 40},
"data_offsets": []int64{6553600, 6963200},
},
},
wantDtype: "U32",
wantShape: []int64{2560, 640},
wantQuantType: "int8",
wantGroupSize: "64",
},
{
name: "with old-style format metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
@@ -371,6 +436,13 @@ func TestParseSafetensorsHeader(t *testing.T) {
}
}
}
if info.QuantType != tt.wantQuantType {
t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType)
}
if info.GroupSize != tt.wantGroupSize {
t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize)
}
})
}
}
@@ -460,7 +532,7 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Create test tensor blobs
// Create test tensor blobs with __metadata__
tensors := []struct {
name string
digest string
@@ -487,10 +559,9 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
},
}
// Create blob files
// Create blob files with tensor keyed by name
var layers []manifest.Layer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
@@ -561,6 +632,391 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
}
func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Create a combined quantized blob with __metadata__
header := map[string]any{
"__metadata__": map[string]string{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 320}, // packed: 2560 / 8 = 320
"data_offsets": []int64{0, 3276800},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80}, // 2560 / 32 = 80
"data_offsets": []int64{3276800, 3686400},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3686400, 4096000},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
digest := "sha256:aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb"
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest,
Size: int64(buf.Len() + 4096000),
Name: "model.layers.0.mlp.up_proj.weight",
},
},
}
result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
if len(result) != 1 {
t.Fatalf("got %d tensors, want 1", len(result))
}
tensor := result[0]
if tensor.Name != "model.layers.0.mlp.up_proj.weight" {
t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name)
}
if tensor.Type != "INT4" {
t.Errorf("Type = %v, want INT4", tensor.Type)
}
// Shape should be unpacked: 320 * 8 = 2560
if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 {
t.Errorf("Shape = %v, want [2560, 2560]", tensor.Shape)
}
}
func TestParseSafetensorsAllHeaders(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantCount int
wantNames []string
wantDtypes []string
wantQuants []string
wantErr bool
}{
{
name: "single tensor blob",
header: map[string]any{
"model.layers.0.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 2560},
"data_offsets": []int64{0, 13107200},
},
},
wantCount: 1,
wantNames: []string{"model.layers.0.weight"},
wantDtypes: []string{"BF16"},
wantQuants: []string{""},
},
{
name: "packed unquantized blob",
header: map[string]any{
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 10240},
"data_offsets": []int64{0, 52428800},
},
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 2560},
"data_offsets": []int64{52428800, 104857600},
},
"model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 2560},
"data_offsets": []int64{104857600, 157286400},
},
},
wantCount: 3,
wantNames: []string{
"model.layers.0.mlp.experts.0.down_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
"model.layers.0.mlp.experts.0.up_proj.weight",
},
wantDtypes: []string{"BF16", "BF16", "BF16"},
wantQuants: []string{"", "", ""},
},
{
name: "packed quantized blob with global metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{16384000, 29491200},
},
"model.layers.0.mlp.experts.0.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{29491200, 31129600},
},
"model.layers.0.mlp.experts.0.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{31129600, 32768000},
},
},
wantCount: 2,
wantNames: []string{
"model.layers.0.mlp.experts.0.gate_proj.weight",
"model.layers.0.mlp.experts.0.up_proj.weight",
},
wantDtypes: []string{"U32", "U32"},
wantQuants: []string{"int4", "int4"},
},
{
name: "packed mixed-precision blob (no global metadata)",
header: map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 2560},
"data_offsets": []int64{16384000, 42598400},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 160},
"data_offsets": []int64{42598400, 43417600},
},
},
wantCount: 2,
wantNames: []string{
"model.layers.0.mlp.experts.0.down_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
},
wantDtypes: []string{"U32", "U32"},
wantQuants: []string{"int8", "int4"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headerJSON, err := json.Marshal(tt.header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
buf.Write(headerJSON)
results, err := parseSafetensorsAllHeaders(&buf)
if (err != nil) != tt.wantErr {
t.Errorf("parseSafetensorsAllHeaders() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if len(results) != tt.wantCount {
t.Fatalf("got %d tensors, want %d", len(results), tt.wantCount)
}
for i, info := range results {
if info.Name != tt.wantNames[i] {
t.Errorf("tensor[%d].Name = %v, want %v", i, info.Name, tt.wantNames[i])
}
if info.Dtype != tt.wantDtypes[i] {
t.Errorf("tensor[%d].Dtype = %v, want %v", i, info.Dtype, tt.wantDtypes[i])
}
if info.QuantType != tt.wantQuants[i] {
t.Errorf("tensor[%d].QuantType = %v, want %v", i, info.QuantType, tt.wantQuants[i])
}
}
})
}
}
func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Create a packed blob with multiple expert tensors (mixed quantization)
header := map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 2560},
"data_offsets": []int64{16384000, 42598400},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 160},
"data_offsets": []int64{42598400, 43417600},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
packedDigest := "sha256:aaaa000000000000000000000000000000000000000000000000000000000001"
blobPath, err := manifest.BlobsPath(packedDigest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write packed blob: %v", err)
}
// Also create a regular (single-tensor) blob
singleHeader := map[string]any{
"model.embed_tokens.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{262144, 2560},
"data_offsets": []int64{0, 1342177280},
},
}
singleHeaderJSON, _ := json.Marshal(singleHeader)
var singleBuf bytes.Buffer
binary.Write(&singleBuf, binary.LittleEndian, uint64(len(singleHeaderJSON)))
singleBuf.Write(singleHeaderJSON)
singleDigest := "sha256:bbbb000000000000000000000000000000000000000000000000000000000002"
singleBlobPath, err := manifest.BlobsPath(singleDigest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(singleBlobPath, singleBuf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write single blob: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: singleDigest,
Size: int64(singleBuf.Len()),
Name: "model.embed_tokens.weight",
},
{
MediaType: manifest.MediaTypeImageTensor,
Digest: packedDigest,
Size: int64(buf.Len()),
Name: "model.layers.0.mlp.experts", // group prefix
},
},
}
result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
// Should have 3 tensors: 1 single + 2 packed main tensors
if len(result) != 3 {
t.Fatalf("got %d tensors, want 3. Tensors: %v", len(result), result)
}
// First tensor should be the single blob
if result[0].Name != "model.embed_tokens.weight" {
t.Errorf("tensor[0].Name = %v, want model.embed_tokens.weight", result[0].Name)
}
if result[0].Type != "BF16" {
t.Errorf("tensor[0].Type = %v, want BF16", result[0].Type)
}
// Packed tensors should have their actual names (sorted)
packedNames := make(map[string]bool)
for _, r := range result[1:] {
packedNames[r.Name] = true
}
if !packedNames["model.layers.0.mlp.experts.0.down_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.down_proj.weight")
}
if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight")
}
}
func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir := t.TempDir()