Compare commits

...

6 Commits

Author SHA1 Message Date
Patrick Devine
2020ab1d14 gofumpt the linter 2026-02-02 18:52:06 -08:00
Patrick Devine
fc0ea22207 blink 2026-02-02 18:25:58 -08:00
Patrick Devine
2639491e1d logo tweaks 2026-02-02 17:54:25 -08:00
Patrick Devine
34541bfa70 launch: new menu system for ollama launch 2026-02-02 17:31:32 -08:00
Jeffrey Morgan
8f4a008139 Add GLM-OCR vision model support (#14024) 2026-02-02 15:39:18 -08:00
Patrick Devine
d8cc798c2b glm 4.7 flash support on experimental engine (#13838) 2026-02-02 15:22:11 -08:00
56 changed files with 7016 additions and 2425 deletions

View File

@@ -36,6 +36,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -1791,6 +1792,159 @@ Environment Variables:
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
}
// runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) {
// errSelectionCancelled is returned when user cancels model selection
errSelectionCancelled := errors.New("cancelled")
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
}
result, err := tui.SelectSingle(title, tuiItems)
if errors.Is(err, tui.ErrCancelled) {
return "", errSelectionCancelled
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description}
}
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, errSelectionCancelled
}
return result, err
}
for {
result, err := tui.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
runModel := func(modelName string) {
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
launchIntegration := func(name string) bool {
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, errSelectionCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
switch result.Selection {
case tui.SelectionNone:
// User quit
return
case tui.SelectionRunModel:
// Run last model directly if configured and still exists
if modelName := config.LastModel(); modelName != "" && config.ModelExists(cmd.Context(), modelName) {
runModel(modelName)
} else {
// No last model or model no longer exists, show picker
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
runModel(modelName)
}
case tui.SelectionChangeRunModel:
// Always show picker
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
runModel(modelName)
case tui.SelectionIntegration:
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
case tui.SelectionInstallIntegration:
// Configure first so the installer can detect the config
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, errSelectionCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
// Now install
fmt.Fprintf(os.Stderr, "Installing %s...\n", result.Integration)
if err := config.InstallIntegration(result.Integration); err != nil {
// Don't print error for user abort (common case)
continue
}
// Verify installation actually succeeded (some installers return 0 on abort)
if !config.IsIntegrationInstalled(result.Integration) {
// Installation was aborted or failed, return to menu
continue
}
fmt.Fprintf(os.Stderr, "%s installed successfully!\n", result.Integration)
// Don't launch after install - installers like openclaw already run
// their onboarding flow via `exec` and launching again would be
// redundant. User can select the integration from the menu.
}
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
@@ -1813,11 +1967,13 @@ func NewCLI() *cobra.Command {
return
}
cmd.Print(cmd.UsageString())
runInteractiveTUI(cmd)
},
}
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -2031,7 +2187,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat),
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd

View File

@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +

View File

@@ -3,12 +3,15 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
)
type integration struct {
@@ -17,6 +20,7 @@ type integration struct {
type config struct {
Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"`
}
func configPath() (string, error) {
@@ -86,6 +90,55 @@ func saveIntegration(appName string, models []string) error {
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
return ""
}
return ic.Models[0]
}
// LastModel returns the last model that was run, or empty string if none.
func LastModel() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastModel
}
// SetLastModel saves the last model that was run.
func SetLastModel(model string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastModel = model
return save(cfg)
}
// ModelExists checks if a model exists on the Ollama server.
func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -4,9 +4,12 @@ import (
"context"
"errors"
"fmt"
"io"
"maps"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
@@ -55,6 +58,165 @@ var integrationAliases = map[string]bool{
"moltbot": true,
}
// integrationInstallURLs maps integration names to their install script URLs.
var integrationInstallURLs = map[string]string{
"claude": "https://claude.ai/install.sh",
"openclaw": "https://openclaw.ai/install.sh",
"droid": "https://app.factory.ai/cli",
"opencode": "https://opencode.ai/install",
}
// CanInstallIntegration returns true if we have an install script for this integration.
func CanInstallIntegration(name string) bool {
_, ok := integrationInstallURLs[name]
return ok
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
switch name {
case "claude":
c := &Claude{}
_, err := c.findPath()
return err == nil
case "openclaw":
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
case "codex":
_, err := exec.LookPath("codex")
return err == nil
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
default:
return true // Assume installed for unknown integrations
}
}
// InstallIntegration downloads and runs the install script for an integration.
func InstallIntegration(name string) error {
url, ok := integrationInstallURLs[name]
if !ok {
return fmt.Errorf("no install script available for %s", name)
}
// Download the install script
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download install script: HTTP %d", resp.StatusCode)
}
script, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read install script: %w", err)
}
// Create a temporary file for the script
tmpDir := os.TempDir()
scriptPath := filepath.Join(tmpDir, fmt.Sprintf("install-%s.sh", name))
if err := os.WriteFile(scriptPath, script, 0o700); err != nil {
return fmt.Errorf("failed to write install script: %w", err)
}
defer os.Remove(scriptPath)
// Execute the script with bash
cmd := exec.Command("bash", scriptPath)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("install script failed: %w", err)
}
return nil
}
// SelectModel lets the user select a model to run.
// ModelItem represents a model for selection.
type ModelItem struct {
Name string
Description string
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
// SelectModelWithSelector prompts the user to select a model using the provided selector.
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return "", err
}
models, err := client.List(ctx)
if err != nil {
return "", err
}
if len(models.Models) == 0 {
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
lastModel := LastModel()
var items []ModelItem
for _, m := range models.Models {
items = append(items, ModelItem{Name: m.Name})
}
// Sort with last model first, then alphabetically
slices.SortFunc(items, func(a, b ModelItem) int {
aIsLast := a.Name == lastModel
bIsLast := b.Name == lastModel
if aIsLast != bIsLast {
if aIsLast {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
return selector("Select model to run:", items)
}
func SelectModel(ctx context.Context) (string, error) {
return SelectModelWithSelector(ctx, defaultSingleSelector)
}
func defaultSingleSelector(title string, items []ModelItem) (string, error) {
selectItems := make([]selectItem, len(items))
for i, item := range items {
selectItems[i] = selectItem(item)
}
return selectPrompt(title, selectItems)
}
func defaultMultiSelector(title string, items []ModelItem, preChecked []string) ([]string, error) {
selectItems := make([]selectItem, len(items))
for i, item := range items {
selectItems[i] = selectItem(item)
}
return multiSelectPrompt(title, selectItems, preChecked)
}
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no integrations available")
@@ -77,8 +239,8 @@ func selectIntegration() (string, error) {
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
@@ -98,13 +260,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
var items []ModelItem
cloudModels := make(map[string]bool)
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
items = append(items, ModelItem{Name: m.Name})
}
if len(items) == 0 {
@@ -137,7 +299,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
slices.SortFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
@@ -151,12 +313,12 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
model, err := single(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
@@ -233,6 +395,11 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selected, nil
}
// selectModels lets the user select models for an integration using default selectors.
func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selectModelsWithSelectors(ctx, name, current, defaultSingleSelector, defaultMultiSelector)
}
func runIntegration(name, modelName string) error {
r, ok := integrations[name]
if !ok {
@@ -242,15 +409,86 @@ func runIntegration(name, modelName string) error {
return r.Run(modelName)
}
// LaunchIntegration launches the named integration using saved config or prompts for setup.
func LaunchIntegration(name string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// Try to use saved config
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
// No saved config - prompt user to run setup
return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
}
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
} else {
fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
}
return nil
}
// ConfigureIntegration allows the user to select/change the model for an integration.
func ConfigureIntegration(ctx context.Context, name string) error {
return ConfigureIntegrationWithSelectors(ctx, name, defaultSingleSelector, defaultMultiSelector)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// The runTUI callback is called when no arguments are provided (alias for main TUI).
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
@@ -267,6 +505,12 @@ Examples:
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// No args - run the main TUI (same as 'ollama')
if len(args) == 0 && modelFlag == "" && !configFlag {
runTUI(cmd)
return nil
}
var name string
if len(args) > 0 {
name = args[0]

View File

@@ -86,8 +86,10 @@ func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
// Mock TUI function (not called in these tests)
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck)
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
@@ -120,6 +122,75 @@ func TestLaunchCmd(t *testing.T) {
})
}
func TestLaunchCmd_TUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
// Will error because claude isn't configured, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --model flag provided")
}
})
t.Run("--config flag bypasses TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
// Will error because no integration specified, but that's OK
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when --config flag provided")
}
})
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
if err == nil {
@@ -160,7 +231,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := LaunchCmd(nil)
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}

471
cmd/tui/selector.go Normal file
View File

@@ -0,0 +1,471 @@
package tui
import (
"errors"
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
var (
selectorTitleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("147"))
selectorItemStyle = lipgloss.NewStyle().
PaddingLeft(4)
selectorSelectedItemStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("147")).
Bold(true)
selectorDescStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorFilterStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
selectorInputStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
selectorCheckboxStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("147"))
selectorDefaultTagStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
selectorHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
selectorMoreStyle = lipgloss.NewStyle().
PaddingLeft(4).
Foreground(lipgloss.Color("241")).
Italic(true)
)
const maxSelectorItems = 10
// ErrCancelled is returned when the user cancels the selection.
var ErrCancelled = errors.New("cancelled")
// SelectItem represents an item that can be selected.
type SelectItem struct {
Name string
Description string
}
// selectorModel is the bubbletea model for single selection.
type selectorModel struct {
title string
items []SelectItem
filter string
cursor int
scrollOffset int
selected string
cancelled bool
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m selectorModel) Init() tea.Cmd {
return nil
}
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
if len(filtered) > 0 && m.cursor < len(filtered) {
m.selected = filtered[m.cursor].Name
}
return m, tea.Quit
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
}
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
return m, nil
}
func (m selectorModel) View() string {
// Clear screen when exiting
if m.cancelled || m.selected != "" {
return ""
}
var s strings.Builder
// Title with filter
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
if item.Description != "" {
s.WriteString(" ")
s.WriteString(selectorDescStyle.Render("- " + item.Description))
}
s.WriteString("\n")
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
s.WriteString("\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • esc cancel"))
return s.String()
}
// SelectSingle prompts the user to select a single item from a list.
func SelectSingle(title string, items []SelectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(selectorModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.selected, nil
}
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
title string
items []SelectItem
itemIndex map[string]int
filter string
cursor int
scrollOffset int
checked map[int]bool
checkOrder []int
cancelled bool
confirmed bool
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
m := multiSelectorModel{
title: title,
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
return m
}
func (m multiSelectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m *multiSelectorModel) toggleItem() {
filtered := m.filteredItems()
if len(filtered) == 0 || m.cursor >= len(filtered) {
return
}
item := filtered[m.cursor]
origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] {
delete(m.checked, origIdx)
for i, idx := range m.checkOrder {
if idx == origIdx {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
} else {
m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx)
}
}
func (m multiSelectorModel) selectedCount() int {
return len(m.checkOrder)
}
func (m multiSelectorModel) Init() tea.Cmd {
return nil
}
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
// Enter confirms if at least one item is selected
if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
// Space always toggles selection
m.toggleItem()
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
}
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
return m, nil
}
func (m multiSelectorModel) View() string {
// Clear screen when exiting
if m.cancelled || m.confirmed {
return ""
}
var s strings.Builder
// Title with filter
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else {
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
item := filtered[idx]
origIdx := m.itemIndex[item.Name]
// Checkbox
var checkbox string
if m.checked[origIdx] {
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
} else {
checkbox = selectorCheckboxStyle.Render("[ ]")
}
// Cursor and name
var line string
if idx == m.cursor {
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
} else {
line = " " + checkbox + " " + item.Name
}
// Default tag
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
line += " " + selectorDefaultTagStyle.Render("(default)")
}
s.WriteString(line)
s.WriteString("\n")
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
s.WriteString("\n")
// Status line
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
return s.String()
}
// SelectMultiple prompts the user to select multiple items from a list.
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
m := newMultiSelectorModel(title, items, preChecked)
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return nil, fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
}
var result []string
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
}
return result, nil
}

526
cmd/tui/tui.go Normal file
View File

@@ -0,0 +1,526 @@
package tui
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/version"
)
const (
logoNormal = ` ▆▁▂▃▂▁▆
▟███████▙
█▙▛▄▄▄▜▟█
▟█▙▀ ▀▟█▙
█████████
▟███████▙
▀▀▀▀▀▀▀▀▀`
logoBlink = ` ▆▁▂▃▂▁▆
▟███████▙
██▛▄▄▄▜██
▟█▙▀ ▀▟█▙
█████████
▟███████▙
▀▀▀▀▀▀▀▀▀`
blinkInterval = 15 * time.Second
blinkDuration = 250 * time.Millisecond
)
type (
blinkMsg struct{}
unblinkMsg struct{}
)
var (
logoStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255")).
Background(lipgloss.Color("0"))
titleStyle = lipgloss.NewStyle().
Bold(true).
MarginBottom(1)
itemStyle = lipgloss.NewStyle().
PaddingLeft(2)
selectedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("147")).
Bold(true)
greyedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("241"))
greyedSelectedStyle = lipgloss.NewStyle().
PaddingLeft(2).
Foreground(lipgloss.Color("243"))
descStyle = lipgloss.NewStyle().
PaddingLeft(4).
Foreground(lipgloss.Color("241"))
modelStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("245"))
notInstalledStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true)
// Dialog styles
dialogBoxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("147")).
Padding(1, 2).
MarginTop(1)
dialogTitleStyle = lipgloss.NewStyle().
Bold(true).
MarginBottom(1)
buttonStyle = lipgloss.NewStyle().
Padding(0, 2).
Background(lipgloss.Color("238")).
Foreground(lipgloss.Color("252"))
activeButtonStyle = lipgloss.NewStyle().
Padding(0, 2).
Background(lipgloss.Color("147")).
Foreground(lipgloss.Color("235")).
Bold(true)
)
type menuItem struct {
title string
description string
integration string // integration name for loading model config, empty if not an integration
isRunModel bool // true for the "Run a model" option
isOthers bool // true for the "Others..." toggle item
}
var mainMenuItems = []menuItem{
{
title: "Run a model",
description: "Start an interactive chat with a local model",
isRunModel: true,
},
{
title: "Launch Claude Code",
description: "Open Claude Code AI assistant",
integration: "claude",
},
{
title: "Launch Open Claw",
description: "Open the Open Claw integration",
integration: "openclaw",
},
}
var othersMenuItem = menuItem{
title: "Others...",
description: "Show additional integrations",
isOthers: true,
}
// getOtherIntegrations returns the list of other integrations, filtering out
// Codex if it's not installed (since it requires npm install).
func getOtherIntegrations() []menuItem {
items := []menuItem{
{
title: "Launch Droid",
description: "Open Droid integration",
integration: "droid",
},
{
title: "Launch Open Code",
description: "Open Open Code integration",
integration: "opencode",
},
}
// Only show Codex if it's already installed
if config.IsIntegrationInstalled("codex") {
items = append([]menuItem{{
title: "Launch Codex",
description: "Open Codex CLI",
integration: "codex",
}}, items...)
}
return items
}
type model struct {
items []menuItem
cursor int
quitting bool
selected bool // true if user made a selection (enter/space)
changeModel bool // true if user pressed 'm' to change model
showDialog bool // true if showing install confirmation dialog
dialogCursor int // 0 = Cancel, 1 = OK
installPending string // integration name pending install
showOthers bool // true if "Others..." is expanded
availableModels map[string]bool // cache of available model names
blinking bool // true when showing blink logo
err error
}
// modelExists checks if a model exists in the cached available models.
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
return false
}
if m.availableModels[name] {
return true
}
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
for modelName := range m.availableModels {
if strings.HasPrefix(modelName, name+":") {
return true
}
}
return false
}
// loadAvailableModels fetches and caches the list of available models.
func (m *model) loadAvailableModels() {
m.availableModels = make(map[string]bool)
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
models, err := client.List(context.Background())
if err != nil {
return
}
for _, mdl := range models.Models {
m.availableModels[mdl.Name] = true
}
}
func (m *model) buildItems() {
others := getOtherIntegrations()
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
m.items = append(m.items, mainMenuItems...)
if m.showOthers {
// Change "Others..." to "Hide others..."
hideItem := menuItem{
title: "Hide others...",
description: "Hide additional integrations",
isOthers: true,
}
m.items = append(m.items, hideItem)
m.items = append(m.items, others...)
} else {
m.items = append(m.items, othersMenuItem)
}
}
func initialModel() model {
m := model{
cursor: 0,
}
m.loadAvailableModels()
m.buildItems()
return m
}
func (m model) Init() tea.Cmd {
return tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
return blinkMsg{}
})
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case blinkMsg:
m.blinking = true
return m, tea.Tick(blinkDuration, func(t time.Time) tea.Msg {
return unblinkMsg{}
})
case unblinkMsg:
m.blinking = false
return m, tea.Tick(blinkInterval, func(t time.Time) tea.Msg {
return blinkMsg{}
})
case tea.KeyMsg:
// Handle dialog input
if m.showDialog {
switch msg.String() {
case "ctrl+c", "esc":
// Close dialog and return to menu
m.showDialog = false
m.installPending = ""
m.dialogCursor = 0
return m, nil
case "left", "h":
m.dialogCursor = 0
return m, nil
case "right", "l":
m.dialogCursor = 1
return m, nil
case "tab":
m.dialogCursor = (m.dialogCursor + 1) % 2
return m, nil
case "enter", " ":
if m.dialogCursor == 1 {
// OK selected - proceed with install
m.selected = true
m.quitting = true
return m, tea.Quit
}
// Cancel selected - close dialog and return to menu
m.showDialog = false
m.installPending = ""
m.dialogCursor = 0
return m, nil
}
// Unknown key in dialog - ignore
return m, nil
}
// Handle main menu input
switch msg.String() {
case "ctrl+c", "q", "esc":
m.quitting = true
return m, tea.Quit
case "up", "k":
if m.cursor > 0 {
m.cursor--
}
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
case "enter", " ":
item := m.items[m.cursor]
// Handle "Others..." toggle
if item.isOthers {
m.showOthers = !m.showOthers
m.buildItems()
// Keep cursor on the Others/Hide item
if m.cursor >= len(m.items) {
m.cursor = len(m.items) - 1
}
return m, nil
}
// Check if integration is installed
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
// Only show install dialog if we can install it
if config.CanInstallIntegration(item.integration) {
m.showDialog = true
m.dialogCursor = 1 // Default to OK
m.installPending = item.integration
return m, nil
}
// Can't install - just try to launch (will show error with install instructions)
}
m.selected = true
m.quitting = true
return m, tea.Quit
case "m":
// Allow model change for integrations and run model
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
// For integrations, check if installed first
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
// Only show install dialog if we can install it
if config.CanInstallIntegration(item.integration) {
m.showDialog = true
m.dialogCursor = 1
m.installPending = item.integration
return m, nil
}
// Can't install - just try to launch (will show error with install instructions)
}
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
}
}
return m, nil
}
func (m model) View() string {
if m.quitting {
return ""
}
logo := logoNormal
if m.blinking {
logo = logoBlink
}
versionText := "\n\n Ollama v" + version.Version
logoRendered := logoStyle.Render(logo)
logoBlock := lipgloss.NewStyle().Padding(0, 1).Background(lipgloss.Color("0")).Render(logoRendered)
versionBlock := titleStyle.Render(versionText)
header := lipgloss.JoinHorizontal(lipgloss.Top, logoBlock, versionBlock)
s := header + "\n\n"
for i, item := range m.items {
cursor := " "
style := itemStyle
isInstalled := true
if item.integration != "" {
isInstalled = config.IsIntegrationInstalled(item.integration)
}
if m.cursor == i {
cursor = "▸ "
if isInstalled {
style = selectedStyle
} else {
style = greyedSelectedStyle
}
} else if !isInstalled && item.integration != "" {
style = greyedStyle
}
title := item.title
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
} else if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
title += " " + modelStyle.Render("("+mdl+")")
}
} else if item.isRunModel {
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
title += " " + modelStyle.Render("("+mdl+")")
}
}
s += style.Render(cursor+title) + "\n"
s += descStyle.Render(item.description) + "\n\n"
}
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Render("↑/↓ navigate • enter select • m change model • esc quit")
// Show dialog if active
if m.showDialog {
s += "\n\n"
integrationName := "this integration"
switch m.installPending {
case "claude":
integrationName = "Claude Code"
case "openclaw":
integrationName = "Open Claw"
case "codex":
integrationName = "Codex"
case "droid":
integrationName = "Droid"
case "opencode":
integrationName = "Open Code"
}
dialogContent := dialogTitleStyle.Render(fmt.Sprintf("Install %s?", integrationName)) + "\n\n"
dialogContent += fmt.Sprintf("%s is not installed. Would you like to install it now?\n\n", integrationName)
cancelBtn := buttonStyle.Render("Cancel")
okBtn := buttonStyle.Render(" OK ")
if m.dialogCursor == 0 {
cancelBtn = activeButtonStyle.Render("Cancel")
} else {
okBtn = activeButtonStyle.Render(" OK ")
}
dialogContent += cancelBtn + " " + okBtn
s += dialogBoxStyle.Render(dialogContent)
}
return s
}
// Selection represents what the user selected
type Selection int
const (
SelectionNone Selection = iota
SelectionRunModel
SelectionChangeRunModel
SelectionIntegration // Generic integration selection
SelectionChangeIntegration // Generic change model for integration
SelectionInstallIntegration
)
// Result contains the selection and any associated data
type Result struct {
Selection Selection
Integration string // integration name if applicable
}
// Run starts the TUI and returns the user's selection
func Run() (Result, error) {
m := initialModel()
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
}
fm := finalModel.(model)
if fm.err != nil {
return Result{Selection: SelectionNone}, fm.err
}
// User quit without selecting
if !fm.selected && !fm.changeModel {
return Result{Selection: SelectionNone}, nil
}
// Handle install request
if fm.installPending != "" {
return Result{
Selection: SelectionInstallIntegration,
Integration: fm.installPending,
}, nil
}
item := fm.items[fm.cursor]
// Handle model change request
if fm.changeModel {
if item.isRunModel {
return Result{Selection: SelectionChangeRunModel}, nil
}
return Result{
Selection: SelectionChangeIntegration,
Integration: item.integration,
}, nil
}
// Handle selection
if item.isRunModel {
return Result{Selection: SelectionRunModel}, nil
}
return Result{
Selection: SelectionIntegration,
Integration: item.integration,
}, nil
}

View File

@@ -313,6 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:

455
convert/convert_glmocr.go Normal file
View File

@@ -0,0 +1,455 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
//
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
rotaryDim := int(float32(headDim) * partialRotaryFactor)
if rotaryDim%2 != 0 {
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
}
// Handle 1D (bias) or 2D (weight) tensors
is1D := len(shape) == 1
var inFeatures int
if is1D {
inFeatures = 1
} else {
inFeatures = int(shape[1])
}
outFeatures := int(shape[0])
nEffectiveHeads := outFeatures / headDim
if nEffectiveHeads != nHeads {
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
}
// Reshape to [n_heads, head_dim, in_features]
reshaped := make([]float32, len(data))
copy(reshaped, data)
// Permute the rotary dimensions: even indices first, then odd
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
result := make([]float32, len(data))
halfRotary := rotaryDim / 2
for h := range nEffectiveHeads {
for f := range inFeatures {
for i := range halfRotary {
// Even dim (0, 2, 4, ...) -> position i
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
dstIdx := h*headDim*inFeatures + i*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
}
// Non-rotary part: copy as-is
for i := rotaryDim; i < headDim; i++ {
srcIdx := h*headDim*inFeatures + i*inFeatures + f
result[srcIdx] = reshaped[srcIdx]
}
}
}
return result, nil
}
}
type glmOcrModel struct {
ModelParameters
TextConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
RMSNormEps float32 `json:"rms_norm_eps"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeParameters struct {
RopeType string `json:"rope_type"`
MRopeSection []int32 `json:"mrope_section"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
Depth uint32 `json:"depth"`
NumHeads uint32 `json:"num_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
OutHiddenSize uint32 `json:"out_hidden_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
ImageStartTokenID uint32 `json:"image_start_token_id"`
ImageEndTokenID uint32 `json:"image_end_token_id"`
VideoStartTokenID uint32 `json:"video_start_token_id"`
VideoEndTokenID uint32 `json:"video_end_token_id"`
ImageTokenID uint32 `json:"image_token_id"`
VideoTokenID uint32 `json:"video_token_id"`
// Preprocessor config (preprocessor_config.json)
Preprocessor struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"-"`
}
var _ ModelConverter = (*glmOcrModel)(nil)
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.Preprocessor)
}
func (m *glmOcrModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "glmocr"
// Text model parameters
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
kv["glmocr.attention.key_length"] = headDim
kv["glmocr.attention.value_length"] = headDim
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
}
// Vision model parameters
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
// Preprocessor-derived image settings (min/max pixels and normalization)
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
if m.Preprocessor.Size.ShortestEdge > 0 {
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
}
if m.Preprocessor.Size.LongestEdge > 0 {
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
}
if len(m.Preprocessor.ImageMean) == 3 {
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
}
if len(m.Preprocessor.ImageStd) == 3 {
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
}
// Special tokens
kv["glmocr.image_token_id"] = m.ImageTokenID
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
kv["glmocr.video_token_id"] = m.VideoTokenID
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
return kv
}
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
skipLayer := func(name string) bool {
// Tensor names are already replaced to "blk.N.xxx" format
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(name)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return blkNum >= numLayers
}
for _, t := range ts {
name := t.Name()
// Skip next-n prediction layers (layers >= num_hidden_layers)
if skipLayer(name) {
continue
}
// Split ffn_gate_up into separate gate and up projections
if strings.Contains(name, "ffn_gate_up") {
for t := range splitDim(t, 0,
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
) {
out = append(out, t)
}
continue
}
if strings.HasSuffix(name, "patch_embd.weight") {
shape := t.Shape()
if len(shape) == 5 && shape[2] == 2 {
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
t0 := t.Clone()
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t0,
})
t1 := t.Clone()
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t1,
})
continue
}
if len(shape) == 4 {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
// Fall through to default handling
}
// Handle pre-split patch embedding weights
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
if strings.Contains(name, "patch_embd.0.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.Contains(name, "patch_embd.1.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Handle .weight.0 and .weight.1 suffix patterns
if strings.HasSuffix(name, "patch_embd.weight.0") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.HasSuffix(name, "patch_embd.weight.1") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
// We permute at conversion time so the weights work correctly with GGML's kernel
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
// Get config values for permutation
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
// Use appropriate head count: nHeads for Q, nKVHeads for K
effectiveHeads := nHeads
if strings.Contains(name, "attn_k.") {
effectiveHeads = nKVHeads
}
permutedT := t.Clone()
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: permutedT,
})
continue
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *glmOcrModel) Replacements() []string {
return []string{
// Vision encoder
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
"model.visual.patch_embed.proj", "v.patch_embd",
"model.visual.blocks", "v.blk",
"model.visual.post_layernorm", "v.post_ln",
"model.visual.downsample", "mm.patch_merger",
// Vision attention
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"attn.q_norm", "attn_q_norm",
"attn.k_norm", "attn_k_norm",
// Vision norms
"norm1", "ln1",
"norm2", "ln2",
// Vision MLP
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
// Merger (multimodal projector)
"model.visual.merger.proj", "mm.model.fc",
"model.visual.merger.post_projection_norm", "mm.post_norm",
"model.visual.merger.gate_proj", "mm.gate",
"model.visual.merger.up_proj", "mm.up",
"model.visual.merger.down_proj", "mm.down",
// Language model
"model.language_model.embed_tokens", "token_embd",
"model.language_model.layers", "blk",
"model.language_model.norm", "output_norm",
"lm_head", "output",
// Language model attention
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
// Language model norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"post_self_attn_layernorm", "post_attn_norm",
"post_mlp_layernorm", "post_ffn_norm",
// Language model MLP (remove mlp. prefix so ffn_* names work)
"mlp.gate_up_proj", "ffn_gate_up",
"mlp.down_proj", "ffn_down",
}
}

View File

@@ -99,6 +99,7 @@ func (st safetensor) Kind() uint32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
!strings.HasPrefix(st.name, "mm.") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -270,6 +270,7 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"glmocr",
"lfm2",
}, kv.Architecture())
}
@@ -859,6 +860,7 @@ func (f GGML) FlashAttention() bool {
"bert",
"gemma3",
"glm4moelite",
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",

18
go.mod
View File

@@ -21,10 +21,12 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/mattn/go-runewidth v0.0.14
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
@@ -37,22 +39,34 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tkrajina/go-reflector v0.5.5 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

36
go.sum
View File

@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -148,13 +164,17 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -162,6 +182,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -180,8 +206,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -218,6 +245,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -304,6 +333,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=

View File

@@ -170,6 +170,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
GELU_ERF(ctx Context) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor

View File

@@ -1581,6 +1581,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {

View File

@@ -0,0 +1,174 @@
package glmocr
import (
"image"
"log/slog"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
temporalPatchSize int
spatialMergeSize int
minPixels int
maxPixels int
factor int
imageMean [3]float32
imageStd [3]float32
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
// Read normalization values from config if available, otherwise use CLIP defaults
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 336)),
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
spatialMergeSize: spatialMergeSize,
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
factor: patchSize * spatialMergeSize,
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
}
}
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
temporalFactor := p.temporalPatchSize
numFrames := temporalFactor // single image
if height < factor || width < factor {
// Scale up small images
scale := float64(factor) / float64(min(height, width))
height = int(math.Ceil(float64(height) * scale))
width = int(math.Ceil(float64(width) * scale))
}
if temporalFactor <= 0 {
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
temporalFactor = 1
}
if numFrames < temporalFactor {
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
numFrames = temporalFactor
}
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
if tBar*hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if tBar*hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
img = imageproc.Composite(img)
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
// We keep [C, H, W] for patch extraction
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
// Calculate grid dimensions (after Conv2D patching)
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // Single image
ImageHeight: resizedHeight,
ImageWidth: resizedWidth,
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, err
}
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := 3
patchSize := p.patchSize
mergeSize := p.spatialMergeSize
temporalPatchSize := p.temporalPatchSize
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
srcIdx := c*height*width + y*width + x
dstIdx := channelOffset + (py * patchSize) + px
result[dstIdx] = pixels[srcIdx]
}
}
if temporalPatchSize > 1 {
frameSize := patchSize * patchSize
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[channelOffset:channelOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@@ -0,0 +1,235 @@
package glmocr
import (
"bytes"
"errors"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
PatchMerger *PatchMerger `gguf:"mm"`
ImageProcessor
imageTokenID int32
imageStartTokenID int32
imageEndTokenID int32
}
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: allEOS,
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
imageTokenID: int32(c.Uint("image_token_id", 59280)),
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Blocks) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
// Create pixel values tensor from flattened patches
// Shape: [patchDim, numPatches]
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
// Forward through vision encoder
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
// Forward through downsample (patch merger)
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
return nil, errors.New("glmocr: missing vision downsample weights")
}
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
if m.PatchMerger == nil {
return nil, errors.New("glmocr: missing patch merger weights")
}
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
// Reset position cache
m.TextModel.positionCache = m.TextModel.positionCache[:0]
m.TextModel.ropeDelta = 0
pos := int32(0)
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
continue
}
// Get grid info for position calculation
grid := inp.Multimodal[0].Data.(*Grid)
mergedH := grid.Height / m.VisionModel.spatialMergeSize
mergedW := grid.Width / m.VisionModel.spatialMergeSize
// Add image start token
result = append(result, &input.Input{Token: m.imageStartTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
// Add image tokens with multimodal data
// All image tokens share the same base position for temporal dimension
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
basePos := pos
sameBatch := tokensPerGrid - 1
if sameBatch < 0 {
sameBatch = 0
}
result = append(result, &input.Input{
Token: m.imageTokenID,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: sameBatch,
})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
// Add placeholder tokens for remaining positions
// All image tokens use the same base position (temporal stays constant)
for range tokensPerGrid - 1 {
result = append(result, &input.Input{Token: m.imageTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
}
// Advance position by max(mergedH, mergedW) after image tokens
pos = basePos + int32(max(mergedH, mergedW))
// Add image end token
result = append(result, &input.Input{Token: m.imageEndTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
}
// Compute rope delta for continuation after the prefill segment:
// delta = (max_position_id + 1) - sequence_length
if len(m.TextModel.positionCache) > 0 {
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
ctx.Forward(hiddenStates)
// Build position slices for M-RoPE
positionSlice := func() [][]int32 {
s := [][]int32{
make([]int32, len(batch.Positions)), // temporal
make([]int32, len(batch.Positions)), // height
make([]int32, len(batch.Positions)), // width
make([]int32, len(batch.Positions)), // unused (zeros)
}
for i, position := range batch.Positions {
// Translate through position cache or continue sequence
if position < int32(len(m.TextModel.positionCache)) {
position = m.TextModel.positionCache[position]
} else if len(m.TextModel.positionCache) > 0 {
// Continue sequence after cached positions using ropeDelta
position = position + m.TextModel.ropeDelta
}
s[0][i] = position
s[1][i] = position
s[2][i] = position
}
return s
}()
// Inject vision embeddings and adjust positions for image tokens
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
w := grid.Width / m.VisionModel.spatialMergeSize
for i := range img.Dim(1) {
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
// Process through transformer layers
for i, layer := range m.TextModel.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glmocr", New)
}

View File

@@ -0,0 +1,190 @@
package glmocr
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type TextModelOptions struct {
hiddenSize int
numHeads int
numKVHeads int
headDim int
rotaryDim int
intermediateSize int
eps float32
ropeBase float32
mropeSections []int
}
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
// With 4 sections for [temporal, height, width, unused]
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Separate Q, K, V projections
q := sa.Query.Forward(ctx, hiddenStates)
k := sa.Key.Forward(ctx, hiddenStates)
v := sa.Value.Forward(ctx, hiddenStates)
// Reshape for GQA
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
// Apply M-RoPE (multi-resolution rotary position embeddings)
q = opts.applyMRoPE(ctx, q, positions)
k = opts.applyMRoPE(ctx, k, positions)
// Scaled dot-product attention with KV cache
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
return sa.Output.Forward(ctx, kqv)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type TextDecoderLayer struct {
// Input layernorm (before attention)
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
// Post self-attention layernorm (after attention, before residual add)
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
// FFN input layernorm (after first residual, before MLP)
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
// Post MLP layernorm (after MLP, before residual add)
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
}
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
// Attention block
residual := hiddenStates
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
// Prune to output positions in final layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP block
residual = hiddenStates
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextDecoderLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextModelOptions
// positionCache stores the M-RoPE position for each token in the sequence.
// This is needed because image tokens share the same base position but have
// different height/width offsets, and the end token position depends on the
// image grid dimensions.
positionCache []int32
ropeDelta int32
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// Clear position cache when KV cache shifts
m.positionCache = nil
m.ropeDelta = 0
return m.applyMRoPE(ctx, key, shift), nil
}
func newTextModel(c fs.Config) *TextModel {
hiddenSize := int(c.Uint("embedding_length", 1536))
numHeads := int(c.Uint("attention.head_count", 16))
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
intermediateSize := int(c.Uint("feed_forward_length", 4608))
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
ropeBase := c.Float("rope.freq_base", 10000)
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
if ropeDim <= 0 {
ropeDim = headDim
}
mropeSections := c.Ints("rope.mrope_section")
var sectionInts []int
if len(mropeSections) > 0 {
sectionInts = make([]int, len(mropeSections))
for i, section := range mropeSections {
sectionInts[i] = int(section)
}
} else {
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
// For rotaryDim=64 this yields [8, 12, 12].
total := ropeDim / 2
if total <= 0 {
total = 32
}
s0 := total * 2 / 8
s1 := total * 3 / 8
s2 := total - s0 - s1
sectionInts = []int{s0, s1, s2}
}
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
rotaryDim := ropeDim
return &TextModel{
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
TextModelOptions: &TextModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
rotaryDim: rotaryDim,
intermediateSize: intermediateSize,
eps: eps,
ropeBase: ropeBase,
mropeSections: sectionInts,
},
}
}

View File

@@ -0,0 +1,355 @@
package glmocr
import (
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type Grid struct {
Height int // Number of patches in height direction
Width int // Number of patches in width direction
Temporal int
ImageHeight int // Full image height in pixels
ImageWidth int // Full image width in pixels
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
numChannels int
patchSize int
temporalPatchSize int
imageSize int
spatialMergeSize int
outHiddenSize int
intermediateSize int
eps float32
}
type VisionPatchEmbed struct {
Proj *nn.Conv2D `gguf:"patch_embd_0"`
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
Bias ml.Tensor `gguf:"patch_embd.bias"`
}
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
_ = grid // patches are already in merge-block order
// pixelValues shape: [patchDim, numPatches]
numPatches := pixelValues.Shape()[1]
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Slice temporal frames for Conv2D (simulate Conv3D)
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
hiddenStates = hiddenStates.Add(ctx, out1)
}
// Flatten to [hidden_size, num_patches]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
if pe.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
}
return hiddenStates
}
type VisionSelfAttention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Combined QKV projection: [3*hidden_size, batch_size]
qkv := sa.QKV.Forward(ctx, hiddenStates)
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
// ChunkSections returns views - must make contiguous before further operations
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
q := chunks[0].Contiguous(ctx)
k := chunks[1].Contiguous(ctx)
v := chunks[2].Contiguous(ctx)
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
// Apply Q-norm and K-norm after head reshape
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
q = sa.QNorm.Forward(ctx, q, opts.eps)
k = sa.KNorm.Forward(ctx, k, opts.eps)
// Apply rotary position embeddings with vision-style 2D positions.
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
// We provide H/W sections and leave the remaining sections empty.
ropeFreqBase := float32(10000.0)
section := opts.headDim / 4
if section <= 0 {
section = 1
}
sections := []int{section, section, 0, 0}
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Try flash attention first (ScaledDotProductAttention), fall back to manual
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
slog.Warn("glmocr: vision attention falling back to manual attention",
"batchSize", batchSize, "numHeads", opts.numHeads,
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
// Manual attention fallback
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Attention scores
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, scale)
kq = kq.Softmax(ctx)
// Attention output: v @ kq (note: v first)
kqv := v.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type VisionBlock struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Pre-norm architecture
residual := hiddenStates
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionDownsample struct {
*nn.Conv2D
}
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
// Apply spatial downsampling via Conv2D
// Input: [hidden_size, num_patches] where patches are in merge-block order
if d.Conv2D == nil || d.Weight == nil {
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
return hiddenStates // Return input unchanged as fallback
}
merge := opts.spatialMergeSize
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
// ggml semantics: result.ne[perm[i]] = input.ne[i]
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// Step 3: Apply Conv2D without bias (bias added after reshape)
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
s0, s1 := merge, merge
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
// Step 5: Add bias after reshape
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
if d.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
}
return hiddenStates
}
type PatchMerger struct {
// GGUF tags align with mm.* keys used by the model
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
}
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Linear projection
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
// Post-projection layer norm + GELU ERF
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.GELU_ERF(ctx)
// Force a copy to avoid in-place mutation issues with GELU_ERF
hiddenStates = hiddenStates.Contiguous(ctx)
// SwiGLU MLP: down(silu(gate(x)) * up(x))
gateOut := m.GateProj.Forward(ctx, hiddenStates)
upOut := m.UpProj.Forward(ctx, hiddenStates)
gate := gateOut.SILU(ctx, upOut)
return m.DownProj.Forward(ctx, gate)
}
type VisionModel struct {
PatchEmbed *VisionPatchEmbed
Blocks []VisionBlock `gguf:"blk"`
PostLN *nn.RMSNorm `gguf:"post_ln"`
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings from flattened patches
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
// Create position IDs for RoPE (spatial grid)
// Patches are already in merge-block order from preprocessing
positions := m.createPositions(ctx, grid)
// Process through vision blocks
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
}
// Post-layernorm
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
// Note: Downsample is now applied separately in Model.EncodeMultimodal
// so mm.patch_merger remains a distinct module
return hiddenStates
}
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
// Create spatial position IDs for vision RoPE
// Position layout: [height, width, height, width] - 4 sections for mrope
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
// This follows the GLM-OCR rot_pos_emb layout
numPatches := grid.Height * grid.Width
mergeRatio := m.spatialMergeSize
// Build position arrays in merge-block order
// Each merge_ratio x merge_ratio block of patches is grouped together
hpos := make([]int32, numPatches)
wpos := make([]int32, numPatches)
ptr := 0
for y := 0; y < grid.Height; y += mergeRatio {
for x := 0; x < grid.Width; x += mergeRatio {
for dy := range mergeRatio {
for dx := range mergeRatio {
hpos[ptr] = int32(y + dy)
wpos[ptr] = int32(x + dx)
ptr++
}
}
}
}
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
// keep remaining sections zeroed to match its conventions.
zeros := make([]int32, numPatches)
s := [][]int32{
hpos, // Section 0: height
wpos, // Section 1: width
zeros, // Section 2: unused
zeros, // Section 3: unused
}
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
}
func newVisionModel(c fs.Config) *VisionModel {
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
patchSize := int(c.Uint("vision.patch_size", 14))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
imageSize := int(c.Uint("vision.image_size", 336))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
return &VisionModel{
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
numChannels: numChannels,
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
imageSize: imageSize,
spatialMergeSize: spatialMergeSize,
outHiddenSize: outHiddenSize,
intermediateSize: intermediateSize,
eps: eps,
},
}
}

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"

17
model/parsers/glmocr.go Normal file
View File

@@ -0,0 +1,17 @@
package parsers
import "github.com/ollama/ollama/api"
// GlmOcrParser is the GLM46 parser with thinking disabled.
type GlmOcrParser struct {
GLM46Parser
}
func (p *GlmOcrParser) HasThinkingSupport() bool {
return false
}
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}

View File

@@ -71,6 +71,8 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "glm-ocr":
return &GlmOcrParser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":

109
model/renderers/glmocr.go Normal file
View File

@@ -0,0 +1,109 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GlmOcrRenderer struct{}
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
enableThinking := false
thinkingExplicitlySet := false
if thinkValue != nil {
enableThinking = thinkValue.Bool()
thinkingExplicitlySet = true
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>\n")
if message.Thinking != "" {
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
} else {
sb.WriteString("<think></think>")
}
if message.Content != "" {
sb.WriteString("\n" + strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
sb.WriteString("\n")
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
sb.WriteString("\n")
}
}
sb.WriteString("<|assistant|>\n")
if thinkingExplicitlySet && !enableThinking {
sb.WriteString("<think></think>\n")
}
return sb.String(), nil
}
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}

View File

@@ -82,6 +82,8 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
imageRunner = true
mlxRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,11 +563,20 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
server, err := mlxrunner.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -13,6 +13,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
@@ -34,7 +35,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
@@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
var parserName, rendererName string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
capabilities = append(capabilities, "thinking")
}
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
newManifestWriter(opts, capabilities, "", ""),
progressFn,
)
}
@@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
@@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
// Check architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
// Check the architecture list
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
return true
}
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
return true
}
}
}
return false
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}

View File

@@ -13,7 +13,11 @@ import (
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Supported quantization types:
// - "q4": affine 4-bit, group_size=32 (with qbiases)
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
// - "q8": affine 8-bit, group_size=64 (with qbiases)
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
@@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp4":
// affine mode: group_size=32, bits=4
case "q4":
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
case "nvfp4":
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
case "q8":
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
case "mxfp8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}

View File

@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool {
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
return "q4"
case "Q8", "INT8", "FP8":
return "q8"
case "NVFP4":
return "nvfp4"
case "MXFP8":
return "mxfp8"
default:
return quantize
}
}
// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
switch normalizeQuantType(quantize) {
case "nvfp4":
return 16
case "q4":
return 32
case "mxfp8":
return 32
case "q8":
return 64
default:
return 32
}
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: q4 (less sensitive)
// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
return ""
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
return false
// Normalize quantization type to canonical form
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, q4/mxfp8: 32, q8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "q8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
return ""
}
return true
// Skip routing gate weights (should stay high precision)
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
return ""
}
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
return quantNorm
}
// Attention MLA weights - keep unquantized (bf16)
// These are highly sensitive: errors accumulate in the KV cache over time
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
if strings.Contains(name, "q_a_proj") ||
strings.Contains(name, "q_b_proj") ||
strings.Contains(name, "kv_a_proj") ||
strings.Contains(name, "kv_b_proj") {
return "" // No quantization - keep bf16
}
// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "q8"
}
// Output projection, gate/up weights - use requested quantization (Q4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
strings.Contains(name, "up_proj") {
return quantNorm
}
// LM head - use requested quantization
if strings.Contains(name, "lm_head") {
return quantNorm
}
// Default to requested quantization for other weights
return quantNorm
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
@@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
quantizeType = quantize
if quantize != "" {
quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
@@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
// Create model_index.json with quantization info if quantizing
if quantize != "" {
modelIndex := map[string]any{
"quantization": strings.ToUpper(quantize),
"group_size": getQuantGroupSize(quantize),
}
indexData, err := json.MarshalIndent(modelIndex, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal model_index.json: %w", err)
}
indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
if err != nil {
return fmt.Errorf("failed to create model_index.json layer: %w", err)
}
layers = append(layers, indexLayer)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {

View File

@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
name string
tensor string
shape []int32
quantize string
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
// Group size divisibility tests
// FP8/FP4 require divisible by 32
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
// NVFP4 requires divisible by 16
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}
@@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}

View File

@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp4", "fp8":
case "", "q4", "q8", "nvfp4", "mxfp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
quantizeType = quantize
}
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
// MLX requires the last dimension to be divisible by the group size.
// nvfp4: 16, q4/mxfp8: 32, q8: 64
func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
groupSize := int32(32)
switch strings.ToUpper(quantize) {
case "NVFP4":
groupSize = 16
case "Q8":
groupSize = 64
}
return shape[len(shape)-1]%groupSize == 0
}

View File

@@ -9,6 +9,7 @@ type Cache interface {
Offset() int
Len() int
State() []*mlx.Array
Reset()
}
type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// Reset clears the cache state for a new generation session
func (c *KVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
}
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
// Reset clears the cache state for a new generation session
func (c *RotatingKVCache) Reset() {
c.keys = nil
c.values = nil
c.offset = 0
c.idx = 0
}

View File

@@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
return filepath.Join(m.BlobDir, blobName)
}
// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
// GetTensorLayers returns tensor layers, optionally filtered by component.
// If component is empty, returns all tensor layers (for LLM models).
// If component is specified (e.g., "text_encoder", "transformer", "vae"),
// returns only layers with that prefix.
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
if component == "" || strings.HasPrefix(layer.Name, component+"/") {
layers = append(layers, layer)
}
}
@@ -206,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
info.Quantization = "Q8"
break
}
}

View File

@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
// Stack stacks arrays along a new axis (axis 0 by default)
func Stack(arrays []*Array, axis int) *Array {
handles := make([]C.mlx_array, len(arrays))
for i, arr := range arrays {
handles[i] = arr.c
}
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
res := C.mlx_array_new()
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
C.mlx_vector_array_free(vec)
return newArray(res)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)

View File

@@ -0,0 +1,840 @@
//go:build mlx
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
Factor float32 `json:"factor"`
MscaleAllDim float32 `json:"mscale_all_dim"`
}
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// RoPE scaling
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}
// MLAAttention implements Multi-head Latent Attention with absorption.
// This uses absorbed MLA which operates in latent space for reduced KV cache.
type MLAAttention struct {
// Low-rank query projections
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
// Low-rank KV projections (with shared rope component)
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
}
// Forward computes absorbed MLA attention output.
// This operates in latent space for reduced KV cache memory.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Query path: q_a_proj -> layernorm -> q_b_proj
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
// Split Q into nope and rope parts
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
// KV path: get compressed KV and k_pe
compressedKV := a.KVAProjWithMQA.Forward(x)
// Split into compressed_kv and k_pe (shared rope component)
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
// Apply layernorm to get kv latent representation
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
kvLatent = mlx.ExpandDims(kvLatent, 1)
// Apply RoPE to the rope parts
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
// ABSORBED MLA: project q_nope to latent space
// qNope: [B, num_heads, L, qk_nope_head_dim]
// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
// Result: [B, num_heads, L, kv_lora_rank]
qLatent := a.EmbedQ.Forward(qNope)
// Keys = concat(kvLatent, kPE)
// kvLatent: [B, 1, L, kv_lora_rank]
// kPE: [B, 1, L, qk_rope_head_dim]
// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
// Cache the smaller latent representation
// We cache keys (latent + rope) and use empty values since values are derived from keys
cachedL := L
if c != nil {
// Create placeholder values with 0 dims for cache (we don't actually use cached values)
placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
keys, _ = c.Update(keys, placeholderValues, int(L))
cachedL = int32(keys.Shape()[2])
}
// Values are the first kv_lora_rank dims of keys (slice off rope part)
values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
// Queries = concat(qLatent, qPE)
// qLatent: [B, num_heads, L, kv_lora_rank]
// qPE: [B, num_heads, L, qk_rope_head_dim]
// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
// Attention in latent space
// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
// values: [B, 1, cachedL, kv_lora_rank]
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
// ABSORBED MLA: unembed from latent space
// out: [B, num_heads, L, kv_lora_rank]
// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
// Result: [B, num_heads, L, v_head_dim]
out = a.UnembedOut.Forward(out)
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
// Compute gate logits through linear layer (handles both quantized and non-quantized)
gates := g.Gate.Forward(x)
// Sigmoid scoring
scores := mlx.Sigmoid(gates)
origScores := scores
// Add correction bias if present
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
}
// Group-wise expert selection (simplified for n_group=1)
// Select top-k experts
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
shape := inds.Shape()
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
// Get scores for selected experts
scores = mlx.TakeAlongAxis(origScores, inds, -1)
// Normalize if configured
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
// Apply routing scaling factor
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
// Dequantized weights (used when GatherQMM not available)
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
// Quantized weights (used with GatherQMM for 4/8-bit affine)
GateWeightQ, GateScales, GateBiases *mlx.Array
UpWeightQ, UpScales, UpBiases *mlx.Array
DownWeightQ, DownScales, DownBiases *mlx.Array
// Quantization bits per projection (supports mixed precision Q4/Q8)
GateBits int
UpBits int
DownBits int
// Quantization group size per projection (detected from tensor shapes)
GateGroupSize int
UpGroupSize int
DownGroupSize int
// If true, use GatherQMM with quantized weights
UseQuantized bool
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
topK := cfg.NumExpertsPerTok
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
// Flatten for gather_mm: [B*L, 1, 1, D]
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
// Flatten indices: [B, L, topK] -> [B*L, topK]
idxFlat := mlx.Reshape(indices, B*L, topK)
// Sort for efficient gather (when we have many tokens)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
// Reorder x based on sorted indices
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
var gate, up, hidden, down *mlx.Array
if s.UseQuantized {
// Use GatherQMM for quantized weights (faster, keeps weights quantized)
// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
// Use GatherMM for dequantized/non-quantized weights
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
// Unsort if we sorted
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
// Get expert indices and scores
inds, scores := m.Gate.Forward(x, cfg)
// Apply routed experts
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
// Add shared experts if present
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MLP with residual
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MoE with residual
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []Block `weight:"-"` // Loaded manually due to different block types
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead nn.LinearLayer `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
// computeScale computes the attention scale.
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
func computeScale(cfg *Config) float32 {
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
scale *= s * s
}
return scale
}
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
// Currently only 4-bit and 8-bit affine quantization are supported.
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
Scales *mlx.Array // Quantization scales (nil if not quantized)
Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// getQuantParams returns quantization parameters from model metadata.
// Returns groupSize, bits, and mode for the model's quantization type.
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
// Use metadata group_size if available (overrides default)
if gs := weights.GroupSize(); gs > 0 {
groupSize = gs
}
return groupSize, bits, mode
}
// loadExpertWeight loads an expert weight.
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
// Otherwise dequantizes and returns only the weight.
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
w, _ := weights.GetTensor(path + ".weight")
if w == nil {
return nil
}
// Check if this is a quantized weight by looking for scales
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
scales, _ := weights.GetTensor(scalePath)
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Get quantization params from metadata
groupSize, bits, mode := getQuantParams(weights)
// Update config with group size (for GatherQMM calls)
if cfg.QuantGroupSize == 0 {
cfg.QuantGroupSize = groupSize
}
// If GatherQMM is supported and requested, return quantized components
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
}
// Otherwise dequantize
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
}
return &ExpertWeight{Weight: w}
}
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
// Returns embed_q and unembed_out weights for per-head projections.
//
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Output:
// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
path := prefix + ".self_attn.kv_b_proj"
w, err := weights.GetTensor(path + ".weight")
if err != nil || w == nil {
return nil, nil
}
// Check if quantized and dequantize
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
scales, _ := weights.GetTensor(scalePath)
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := getQuantParams(weights)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
// Split into wk and wv
// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
// wv: [num_heads, v_head_dim, kv_lora_rank]
wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
// Transform for absorbed MLA:
// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
embedQ := mlx.Transpose(wk, 0, 2, 1)
// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
unembedOut := wv
return embedQ, unembedOut
}
// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
Scales *mlx.Array // Stacked scales (nil if not quantized)
Biases *mlx.Array // Stacked biases (nil if not quantized)
Bits int // Quantization bits (4 or 8), 0 if not quantized
GroupSize int // Quantization group size, 0 if not quantized
}
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
weights safetensors.WeightSource,
prefix string,
projName string,
numExperts int32,
useQuantized bool,
cfg *Config,
) *StackedExpertWeights {
var w, s, b []*mlx.Array
var bits, groupSize int
for e := int32(0); e < numExperts; e++ {
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
ew := loadExpertWeight(weights, path, useQuantized, cfg)
if ew == nil {
continue
}
w = append(w, ew.Weight)
if ew.Scales != nil {
s = append(s, ew.Scales)
}
if ew.Biases != nil {
b = append(b, ew.Biases)
}
if e == 0 {
bits = ew.Bits
groupSize = ew.GroupSize
}
}
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
if len(w) > 0 {
result.Weight = mlx.Stack(w, 0)
if len(s) > 0 {
result.Scales = mlx.Stack(s, 0)
}
if len(b) > 0 {
result.Biases = mlx.Stack(b, 0)
}
}
return result
}
// sanitizeExpertWeights stacks individual expert weights into tensors.
// If useQuantized is true and weights support GatherQMM, returns quantized components.
// Otherwise returns dequantized weights with nil scales/biases.
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
return gate, up, down
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
if err := weights.Load(0); err != nil {
return nil, fmt.Errorf("load weight data: %w", err)
}
// Set up quantization parameters (only if model is actually quantized)
// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
quantization := weights.Quantization()
useQuantized := false
if quantization != "" {
_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
// Build tokenizer config with companion files for EOS/BOS token loading
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData, // Already loaded above, contains eos_token_id
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
}
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights (pass cfg so group sizes can be detected)
gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
switchMLP.GateWeightQ = gate.Weight
switchMLP.GateScales = gate.Scales
switchMLP.GateBiases = gate.Biases
switchMLP.GateBits = gate.Bits
switchMLP.GateGroupSize = gate.GroupSize
switchMLP.UpWeightQ = up.Weight
switchMLP.UpScales = up.Scales
switchMLP.UpBiases = up.Biases
switchMLP.UpBits = up.Bits
switchMLP.UpGroupSize = up.GroupSize
switchMLP.DownWeightQ = down.Weight
switchMLP.DownScales = down.Scales
switchMLP.DownBiases = down.Biases
switchMLP.DownBits = down.Bits
switchMLP.DownGroupSize = down.GroupSize
} else {
switchMLP.GateWeight = gate.Weight
switchMLP.UpWeight = up.Weight
switchMLP.DownWeight = down.Weight
}
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: switchMLP,
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return m.LMHead.Forward(h)
}
// Interface methods
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
// When think is true, the prompt ends with <think> to enable reasoning mode.
// When think is false, the prompt ends with </think> to skip reasoning.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
if think {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
return &Renderer{}
}
// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
return &Parser{}
}

View File

@@ -0,0 +1,479 @@
//go:build mlx
package glm4_moe_lite
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type parserState int
const (
parserState_LookingForThinkingOpen parserState = iota
parserState_ThinkingStartedEatingWhitespace
parserState_CollectingThinking
parserState_ThinkingDoneEatingWhitespace
parserState_CollectingContent
parserState_ToolStartedEatingWhitespace
parserState_CollectingToolContent
)
const (
thinkingOpenTag = "<think>"
thinkingCloseTag = "</think>"
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
state parserState
buffer strings.Builder
tools []api.Tool
}
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
return true
}
// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
return true
}
// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = parserState_CollectingThinking
}
return tools
}
type parserEvent interface {
isParserEvent()
}
type eventContent struct {
content string
}
func (eventContent) isParserEvent() {}
type eventRawToolCall struct {
raw string
}
func (eventRawToolCall) isParserEvent() {}
type eventThinkingContent struct {
content string
}
func (eventThinkingContent) isParserEvent() {}
// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case eventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case eventThinkingContent:
thinkingSb.WriteString(event.content)
case eventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Parser) parseEvents() []parserEvent {
var all []parserEvent
keepLooping := true
for keepLooping {
var events []parserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *Parser) eat() ([]parserEvent, bool) {
var events []parserEvent
switch p.state {
case parserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, thinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = parserState_ThinkingStartedEatingWhitespace
} else {
p.state = parserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = parserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case parserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
case parserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, eventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = parserState_ThinkingDoneEatingWhitespace
} else {
p.state = parserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
}
case parserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
case parserState_CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
before, after := p.splitAtTag(toolOpenTag, true)
if len(before) > 0 {
events = append(events, eventContent{content: before})
}
if after == "" {
p.state = parserState_ToolStartedEatingWhitespace
} else {
p.state = parserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
}
case parserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
case parserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, toolCloseTag) {
toolContent, _ := p.splitAtTag(toolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm4 tool call closing tag found but no content before it")
}
events = append(events, eventRawToolCall{raw: toolContent})
p.state = parserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
for i := 1; i <= len(tag) && i <= len(s); i++ {
if strings.HasSuffix(s, tag[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
return len(s) - len(trimmed)
}
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
escaped := escapeContent(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed ToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
value = strings.TrimSpace(value)
// If no type specified, return as string
if len(paramType) == 0 {
return value
}
// Try to parse based on specified types
for _, t := range paramType {
switch t {
case "boolean":
if value == "true" {
return true
}
if value == "false" {
return false
}
case "integer":
var i int64
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
return i
}
case "number":
var f float64
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
return f
}
case "array", "object":
// Try to parse as JSON
var result any
if err := json.Unmarshal([]byte(value), &result); err == nil {
return result
}
}
}
// Default to string
return value
}

View File

@@ -0,0 +1,192 @@
//go:build mlx
package glm4_moe_lite
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestParserThinking(t *testing.T) {
tests := []struct {
name string
input string
thinkEnabled bool
wantContent string
wantThinking string
wantToolCalls int
}{
{
name: "thinking enabled - simple thinking then content",
input: "Let me think about this...</think>Here is my answer.",
thinkEnabled: true,
wantThinking: "Let me think about this...",
wantContent: "Here is my answer.",
},
{
name: "thinking enabled - only thinking",
input: "I need to consider multiple factors...",
thinkEnabled: true,
wantThinking: "I need to consider multiple factors...",
wantContent: "",
},
{
name: "thinking disabled - direct content",
input: "Here is my direct answer.",
thinkEnabled: false,
wantThinking: "",
wantContent: "Here is my direct answer.",
},
{
name: "thinking with tool call",
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
thinkEnabled: true,
wantThinking: "Let me search for that...",
wantContent: "I'll use a tool.",
wantToolCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Parser{}
var thinkValue *api.ThinkValue
if tt.thinkEnabled {
thinkValue = &api.ThinkValue{Value: true}
} else {
thinkValue = &api.ThinkValue{Value: false}
}
// Define tools for tool call tests
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "search",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
p.Init(tools, nil, thinkValue)
content, thinking, calls, err := p.Add(tt.input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if thinking != tt.wantThinking {
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
}
if content != tt.wantContent {
t.Errorf("content = %q, want %q", content, tt.wantContent)
}
if len(calls) != tt.wantToolCalls {
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
}
})
}
}
func TestParserToolCall(t *testing.T) {
p := &Parser{}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
// Initialize with thinking disabled
tv := &api.ThinkValue{Value: false}
p.Init(tools, nil, tv)
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
_, _, calls, err := p.Add(input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
call := calls[0]
if call.Function.Name != "get_weather" {
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
}
location, ok := call.Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Errorf("location = %v, want %q", location, "San Francisco")
}
unit, ok := call.Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Errorf("unit = %v, want %q", unit, "celsius")
}
}
func TestOverlap(t *testing.T) {
tests := []struct {
s string
tag string
want int
}{
{"hello<", "</think>", 1},
{"hello</", "</think>", 2},
{"hello</t", "</think>", 3},
{"hello</th", "</think>", 4},
{"hello</thi", "</think>", 5},
{"hello</thin", "</think>", 6},
{"hello</think", "</think>", 7},
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
{"hello", "</think>", 0},
{"", "</think>", 0},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
got := overlap(tt.s, tt.tag)
if got != tt.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
}
})
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
tests := []struct {
s string
want int
}{
{"hello ", 3},
{"hello\n\t ", 3},
{"hello", 0},
{"", 0},
{" ", 3},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
got := trailingWhitespaceLen(tt.s)
if got != tt.want {
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,205 @@
//go:build mlx
package glm4_moe_lite
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestRendererSimple(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
// Thinking enabled (default)
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererThinkingDisabled(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
tv := &api.ThinkValue{Value: false}
result, err := r.Render(messages, nil, tv)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererMultiTurn(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
{Role: "user", Content: "And 3+3?"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check key parts
if !strings.Contains(result, "[gMASK]<sop>") {
t.Error("missing [gMASK]<sop> prefix")
}
if !strings.Contains(result, "<|user|>What is 2+2?") {
t.Error("missing first user message")
}
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
t.Error("missing assistant message with thinking")
}
if !strings.Contains(result, "<|user|>And 3+3?") {
t.Error("missing second user message")
}
if !strings.HasSuffix(result, "<|assistant|><think>") {
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
}
}
func TestRendererWithSystem(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
t.Error("missing system message")
}
}
func TestRendererWithTools(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What's the weather?"},
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"location"},
},
},
},
}
result, err := r.Render(messages, tools, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check for tool system prompt
if !strings.Contains(result, "<|system|>") {
t.Error("missing system tag for tools")
}
if !strings.Contains(result, "# Tools") {
t.Error("missing tools header")
}
if !strings.Contains(result, "<tools>") {
t.Error("missing tools tag")
}
if !strings.Contains(result, "get_weather") {
t.Error("missing tool name")
}
if !strings.Contains(result, "</tools>") {
t.Error("missing closing tools tag")
}
}
func TestRendererWithToolCalls(t *testing.T) {
r := &Renderer{}
args := api.NewToolCallFunctionArguments()
args.Set("location", "San Francisco")
messages := []api.Message{
{Role: "user", Content: "What's the weather in SF?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<tool_call>get_weather") {
t.Error("missing tool call")
}
if !strings.Contains(result, "<arg_key>location</arg_key>") {
t.Error("missing arg_key")
}
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
t.Error("missing arg_value")
}
if !strings.Contains(result, "</tool_call>") {
t.Error("missing tool call closing tag")
}
if !strings.Contains(result, "<|observation|>") {
t.Error("missing observation tag")
}
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
t.Error("missing tool response")
}
}
func TestFormatToolJSON(t *testing.T) {
input := []byte(`{"name":"test","value":123}`)
result := formatToolJSON(input)
// Should add spaces after : and ,
if !strings.Contains(result, ": ") {
t.Error("should add space after colon")
}
if !strings.Contains(result, ", ") {
t.Error("should add space after comma")
}
}

View File

@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
mlx.Eval(qw, scales, qbiases)
// Handle modes that don't return qbiases (e.g., nvfp4)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
// Supports multiple quantization modes:
// - "affine": scale + zero-point bias (QBiases required)
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (NOT layer bias)
QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
}
return out
}
// MultiLinearLayer is an interface for per-head linear layers.
// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
// Input shape: [B, num_heads, L, input_dims]
// Output shape: [B, num_heads, L, output_dims]
type MultiLinear struct {
Weight *mlx.Array `weight:"weight"`
}
// NewMultiLinear creates a MultiLinear layer with the given weight.
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}
// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting.
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
// Weight: [num_heads, output_dims, input_dims]
// x: [B, num_heads, L, input_dims]
// wT: [num_heads, input_dims, output_dims]
// Result: [B, num_heads, L, output_dims]
wT := mlx.Transpose(ml.Weight, 0, 2, 1)
return mlx.Matmul(x, wT)
}

View File

@@ -1,284 +0,0 @@
//go:build mlx
// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner
import (
"context"
"encoding/json"
"flag"
"fmt"
"image"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// ImageEditModel extends ImageModel with image editing/conditioning capability.
// Models that support input images for editing should implement this interface.
type ImageEditModel interface {
ImageModel
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model ImageModel
modelName string
}
// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to image model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
err := mlx.InitMLX()
if err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(*modelName)
slog.Info("detected model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down image runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("image runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate and decode input images
const maxInputImages = 2
if len(req.Images) > maxInputImages {
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
return
}
var inputImages []image.Image
if len(req.Images) > 0 {
// TODO: add memory check for input images
inputImages = make([]image.Image, len(req.Images))
for i, imgBytes := range req.Images {
img, err := imagegen.DecodeImage(imgBytes)
if err != nil {
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
return
}
inputImages[i] = img
}
slog.Info("decoded input images", "count", len(inputImages))
// Default width/height to first input image dimensions, scaled to max 1024
bounds := inputImages[0].Bounds()
w, h := bounds.Dx(), bounds.Dy()
if w > 1024 || h > 1024 {
if w > h {
h = h * 1024 / w
w = 1024
} else {
w = w * 1024 / h
h = 1024
}
}
req.Width = int32(w)
req.Height = int32(h)
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Generate image using the common interface
ctx := r.Context()
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
var img *mlx.Array
var err error
if len(inputImages) > 0 {
editModel, ok := s.model.(ImageEditModel)
if !ok {
http.Error(w, "model does not support image editing", http.StatusBadRequest)
return
}
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
} else {
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
}
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

View File

@@ -17,17 +17,31 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "FP4", "FP8", or ""
Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
GroupSize() int // Returns quantization group size, or 0 if not specified
}
// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
// QuantizationParams returns groupSize, bits, mode for a quantization type.
// MLX quantization modes:
// - "affine": scale + zero-point bias, group_size=32/64/128
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias)
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "FP4":
case "NVFP4":
// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
// 4-bit quantization with affine mode (scale + qbias)
return 32, 4, "affine"
case "MXFP8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
// 8-bit quantization with affine mode (default for quantized models)
return 64, 8, "affine"
default:
return 32, 8, "affine" // FP8 or unknown
return 32, 8, "affine" // Default to affine
}
}
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
if field.Type == linearLayerType {
if !hasTag {
continue // no tag = skip
}
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
continue
}
// Handle nn.MultiLinearLayer interface fields specially
multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
if field.Type == multiLinearLayerType {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadMultiLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
return prefix + "." + suffix
}
// LoadMultiLinearLayer loads a per-head linear layer from weights.
// Weight shape should be [num_heads, output_dims, input_dims].
// If quantized, always dequantizes since batched quantized matmul isn't supported.
func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
hasScale := weights.HasTensor(scalePath)
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
if hasScale {
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
// Always dequantize for MultiLinear - no batched quantized matmul support
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, _ = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
}
return nn.NewMultiLinear(weight), nil
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
hasScale := weights.HasTensor(scalePath)
if hasScale {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := quantizationParams(weights.Quantization())
// Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
weightShape := weight.Shape()
scalesShape := scales.Shape()
weightCols := int(weightShape[len(weightShape)-1])
scalesCols := int(scalesShape[len(scalesShape)-1])
if mlx.MetalIsAvailable() {
// Detect quantization from tensor shapes
// groupSize = weightCols * packFactor / scalesCols
// Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
var bits, groupSize int
mode := "affine"
// Use metadata to help disambiguate when shapes are ambiguous
// (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
quantType := strings.ToUpper(weights.Quantization())
isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
if groupSize4 == 32 {
// Unambiguous: Q4 with group_size=32
bits = 4
groupSize = 32
} else if groupSize8 == 64 {
// Unambiguous: Q8 with group_size=64
bits = 8
groupSize = 64
} else if groupSize4 == 64 && groupSize8 == 32 {
// Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
if isQ8Type {
bits = 8
groupSize = 32
} else {
bits = 4
groupSize = 64
}
} else {
// Fallback: use global quantization params
_, bits, mode = QuantizationParams(weights.Quantization())
packFactor := 32 / bits
groupSize = weightCols * packFactor / scalesCols
}
// NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,

View File

@@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string {
return ""
}
// GroupSize returns 0 for directory-based weights (use default).
func (mw *ModelWeights) GroupSize() int {
return 0
}
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {

View File

@@ -1,48 +0,0 @@
package imagegen
import (
"runtime"
"testing"
)
// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
// Apple Silicon should be supported
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
// Intel Mac should fail
if err == nil {
t.Error("Expected error on darwin/amd64 (Intel), got nil")
}
if err != nil && err.Error() == "" {
t.Error("Expected meaningful error message for unsupported platform")
}
}
case "linux", "windows":
// Linux/Windows are allowed (CUDA support checked at runtime)
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
// Other platforms should fail
if err == nil {
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
}
}
}
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
// ensures compile-time interface compliance.
// This test documents that requirement.
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}

View File

@@ -20,20 +20,28 @@ type ManifestWeights struct {
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
// LoadWeightsFromManifest creates a weight loader from manifest storage.
// If component is empty, loads all tensors (for LLM models).
// If component is specified, loads only tensors for that component and strips the prefix.
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
layers := manifest.GetTensorLayers(component)
if len(layers) == 0 {
if component == "" {
return nil, fmt.Errorf("no tensor layers found in manifest")
}
return nil, fmt.Errorf("no tensor layers found for component %q", component)
}
// Strip component prefix from tensor names for model loading
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
prefix := component + "/"
tensors := make(map[string]ManifestLayer, len(layers))
for _, layer := range layers {
tensorName := strings.TrimPrefix(layer.Name, prefix)
tensors[tensorName] = layer
if component == "" {
tensors[layer.Name] = layer
} else {
tensorName := strings.TrimPrefix(layer.Name, component+"/")
tensors[tensorName] = layer
}
}
return &ManifestWeights{
@@ -48,19 +56,30 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
// Track native handles to free after batch eval
nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
arrays := make([]*mlx.Array, 0, len(mw.tensors))
for name, layer := range mw.tensors {
path := mw.manifest.BlobPath(layer.Digest)
// Load blob as safetensors (native mmap, zero-copy)
sf, err := mlx.LoadSafetensorsNative(path)
if err != nil {
// Free any handles we've accumulated
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("load %s: %w", name, err)
}
nativeHandles = append(nativeHandles, sf)
// Blob contains single tensor named "data"
arr := sf.Get("data")
if arr == nil {
sf.Free()
for _, h := range nativeHandles {
h.Free()
}
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
}
@@ -68,11 +87,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
}
// ALWAYS make a contiguous copy to ensure independence from mmap
// Make contiguous copy to ensure independence from mmap
arr = mlx.Contiguous(arr)
mlx.Eval(arr)
mw.cache[name] = arr
sf.Free() // Safe to free - arr is now an independent copy
arrays = append(arrays, arr)
}
// Batch evaluate all tensors at once (much faster than one at a time)
mlx.Eval(arrays...)
// Now safe to free all native handles
for _, sf := range nativeHandles {
sf.Free()
}
return nil
@@ -107,18 +133,112 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
}
// Quantization returns the model's quantization type from model_index.json.
// Returns empty string if not quantized or unknown.
// Returns empty string if not quantized.
// Falls back to detecting from tensor names and shapes if not in config.
func (mw *ManifestWeights) Quantization() string {
if mw.manifest == nil {
return ""
}
// Try to read from model_index.json first
var index struct {
Quantization string `json:"quantization"`
}
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
return index.Quantization
}
// Fallback: detect from tensor names
// Check if any tensors have _scale suffix (indicates quantization)
hasScales := false
hasQBias := false
for name := range mw.tensors {
if strings.HasSuffix(name, ".weight_scale") {
hasScales = true
}
if strings.HasSuffix(name, ".weight_qbias") {
hasQBias = true
}
}
if !hasScales {
// No scales = not quantized
return ""
}
return index.Quantization
// Has scales but no qbias = NVFP4 (or other non-affine mode)
if !hasQBias {
return "NVFP4"
}
// Has both scales and qbias = affine mode
// Need to determine FP4 vs FP8 from tensor shapes
// FP4: weight last dim is 1/8 of scales last dim * group_size
// FP8: weight last dim is 1/4 of scales last dim * group_size
//
// For affine mode with group_size=32:
// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
// scales_dim = orig_dim / group_size
// So: weight_dim / scales_dim = group_size / pack_factor
// FP4: ratio = 32/8 = 4
// FP8: ratio = 32/4 = 8
// Find a weight/scale pair to check the ratio
for name := range mw.tensors {
if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
continue
}
scaleName := name + "_scale"
if _, ok := mw.tensors[scaleName]; !ok {
continue
}
// Load both tensors to check shapes
weightLayer := mw.tensors[name]
scaleLayer := mw.tensors[scaleName]
// Get shapes from manifest layer metadata if available
// For now, default to FP4 since it's more common
// The actual shape check would require loading the tensor
// Simple heuristic: check if scale tensor is ~4x smaller than weight
// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
// Rough size heuristic (assuming float16 scales)
// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
if ratio < 12 {
// Closer to 8 = Q4
return "Q4"
}
// Closer to 16 = Q8
return "Q8"
}
// Default to Q4 for affine mode (most common)
return "Q4"
}
// GroupSize returns the quantization group size from model_index.json.
// Returns 0 if not specified (caller should use default based on quantization type).
func (mw *ManifestWeights) GroupSize() int {
if mw.manifest == nil {
return 0
}
var index struct {
GroupSize int `json:"group_size"`
}
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 {
return index.GroupSize
}
return 0
}
// ReleaseAll frees all native handles and clears the tensor cache.

View File

@@ -1,797 +1,144 @@
//go:build mlx
package kvcache
// import (
// "errors"
// "fmt"
// "log/slog"
// "math"
// "slices"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
// DType ml.DType
// // swaWindowSize is the number of tokens that will be included in the mask
// // during attention operations. swaMemorySize is the number of tokens that
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
// // for unlimited or if sliding window attention is not being used.
// swaWindowSize int32
// swaMemorySize int32
// chunkSize int32
// opts CausalOptions
// // maxBatch is the largest batch that we might receive
// maxBatch int
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // size of the current batch
// curBatchSize int
// // locations for data storage for this batch
// curLoc ml.Tensor
// // mask of the cache as used by this batch
// curMask ml.Tensor
// // the active layer for Get and Put
// curLayer int
// // locations in the cache that are needed for this batch
// curCellRange cellRange
// // curSequences is the sequences corresponding to this pass's entries in the cache
// curSequences []int
// // curPositions is the positions corresponding to this pass's entries in the cache
// curPositions []int32
// // ** cache metadata **
// // for each possible location in the cache, stores the position and set of sequences
// // that reference the data there
// cells []cacheCell
// // maps from sequence to the range of locations where it is stored in the cache
// cellRanges map[int]cellRange
// // ** cache data storage **
// shiftFn shiftFn
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// kHeadDims, vHeadDims, numKVHeads map[int]int
// }
// type cacheCell struct {
// pos int32
// sequences []int
// }
// type cellRange struct {
// min int
// max int
// }
// func NewCausalCache(shift shiftFn) *Causal {
// return &Causal{
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// swaMemorySize: memorySize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
// return &Causal{
// chunkSize: chunkSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if c.config.CachePadding == 0 {
// c.config.CachePadding = 1
// }
// if c.config.MaskBatchPadding == 0 {
// c.config.MaskBatchPadding = 1
// }
// // TODO what types do we handle here?
// // if c.config.MaskDType == ml.DTypeOther {
// // c.config.MaskDType = ml.DTypeFloat32
// // }
// if c.swaWindowSize == 0 {
// c.swaWindowSize = math.MaxInt32
// }
// if c.swaMemorySize == 0 {
// c.swaMemorySize = c.swaWindowSize
// }
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
// // causing a cache break. As an optimization, only do this when we have parallel sequences
// // because the extra token will live in the batch buffer and won't get overwritten if we
// // only have a single sequence.
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
// }
// if int(c.swaMemorySize) >= capacity {
// c.swaMemorySize = math.MaxInt32
// }
// if c.swaMemorySize < c.swaWindowSize {
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
// }
// var cacheSize int
// if c.swaMemorySize == math.MaxInt32 {
// cacheSize = maxSequences * capacity
// } else {
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
// }
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
// c.cells = make([]cacheCell, cacheSize)
// c.DType = dtype
// c.cellRanges = make(map[int]cellRange)
// c.backend = backend
// c.maxBatch = maxBatch
// }
// func (c *Causal) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
// // panic("XXX Causal.StartForward")
// c.curBatchSize = len(batch.Positions)
// c.curSequences = batch.Sequences
// c.curPositions = batch.Positions
// c.opts.Except = nil
// var locs []int32
// if !reserve {
// c.updateSlidingWindow()
// var err error
// locs, err = c.findLocs()
// if err != nil {
// return err
// }
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
// for i, pos := range batch.Positions {
// seq := batch.Sequences[i]
// loc := int(locs[i])
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// seqRange = newRange()
// }
// seqRange.min = min(seqRange.min, loc)
// c.curCellRange.min = min(c.curCellRange.min, loc)
// seqRange.max = max(seqRange.max, loc)
// c.curCellRange.max = max(c.curCellRange.max, loc)
// c.cellRanges[seq] = seqRange
// }
// } else {
// // If we are reserving memory, don't update any of the cache metadata but set the size
// // to the worst case.
// locs = make([]int32, c.curBatchSize)
// for i := range locs {
// locs[i] = int32(i)
// }
// c.curCellRange.min = 0
// c.curCellRange.max = len(c.cells) - 1
// }
// // XXX Building up the locs for what's already processed (if any)
// dummyLocs := []int{}
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// } else {
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
// dummyLocs = append(dummyLocs, i)
// }
// }
// }
// }
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
// slog.Info("XXX Causal.StartForward", "locs", locs)
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
// c.curMask = c.buildMask(ctx)
// return nil
// }
// func newRange() cellRange {
// return cellRange{
// min: math.MaxInt,
// max: 0,
// }
// }
// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
// loc := make([]int32, 0, c.curBatchSize)
// for i := range c.cells {
// if len(c.cells[i].sequences) == 0 {
// loc = append(loc, int32(i))
// if len(loc) >= c.curBatchSize {
// return loc, nil
// }
// }
// }
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }
// func (c *Causal) updateSlidingWindow() {
// c.curCellRange = newRange()
// if c.swaMemorySize == math.MaxInt32 {
// for _, seq := range c.curSequences {
// if seqRange, ok := c.cellRanges[seq]; ok {
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
// }
// }
// return
// }
// type lowestPosition struct {
// pos int32
// curBatch bool
// }
// // create a map of unique sequences to the lowest position in that sequence
// lowestPos := make(map[int]lowestPosition)
// for i := range c.curPositions {
// seq := c.curSequences[i]
// lowest, ok := lowestPos[seq]
// if !ok {
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
// } else if c.curPositions[i] < lowest.pos {
// lowest.pos = c.curPositions[i]
// }
// lowestPos[seq] = lowest
// }
// // for any sequences are not part of this batch, clean up any tokens
// // that are no longer needed after the processing of the previous
// // batch
// for seq, seqRange := range c.cellRanges {
// if _, ok := lowestPos[seq]; !ok {
// var last int32
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// last = max(last, c.cells[i].pos)
// }
// }
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
// }
// }
// // delete any entries that are beyond the window of the oldest position in the sequence
// for seq, lowest := range lowestPos {
// oldRange, ok := c.cellRanges[seq]
// if !ok {
// continue
// }
// newRange := newRange()
// for i := oldRange.min; i <= oldRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// newRange.min = min(newRange.min, i)
// newRange.max = max(newRange.max, i)
// }
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
// c.curCellRange.min = min(c.curCellRange.min, i)
// c.curCellRange.max = max(c.curCellRange.max, i)
// }
// }
// }
// c.cellRanges[seq] = newRange
// }
// }
// func roundDown(length, pad int) int {
// return (length / pad) * pad
// }
// func roundUp(length, pad int) int {
// return ((length + pad - 1) / pad) * pad
// }
// // Builds a mask of history x batch indicating whether for each token in the batch the
// // token in the history should apply. This is based on both the sequence and causality (the
// // position of the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// // Align and pad the two dimensions as required by the backend
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// length := c.curCellRange.max - c.curCellRange.min + 1
// mask := make([]float32, batchSize*length)
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// }
// }
// }
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
// // has already been masked out because the sequence doesn't match.
// for i := c.curBatchSize * length; i < len(mask); i++ {
// mask[i] = float32(math.Inf(-1))
// }
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
// // if c.config.MaskDType != ml.DTypeFloat32 {
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
// // }
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
// return maskTensor
// }
// func (c *Causal) SetLayer(layer int) {
// c.curLayer = layer
// }
// type CausalOptions struct {
// // Enabled controls whether the causal mask is generated for a particular index in a batch
// Except []int
// }
// // SetCausal disables causal mask generation for a particular range of indicies in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
// if !slices.Equal(c.opts.Except, opts.Except) {
// c.opts = opts
// if ctx != nil {
// c.curMask = c.buildMask(ctx)
// }
// }
// }
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// key := c.keys[c.curLayer]
// value := c.values[c.curLayer]
// kHeadDim := c.kHeadDims[c.curLayer]
// vHeadDim := c.vHeadDims[c.curLayer]
// numKVHeads := c.numKVHeads[c.curLayer]
// // rowSize := numKVHeads * c.curBatchSize
// // cachedSize := c.curMask.Dim(1)
// cachedSize := c.curLoc.Dim(0)
// // kCellSize := kHeadDim * numKVHeads
// // vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get full cache", "key", key)
// slog.Info("XXX Causal.Get full cache", "value", value)
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
// // panic("XXX")
// // fmt.Fprintln(os.Stderr, key.ToString())
// // panic("full cache value")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not converted
// // vHeadDim := value.Dim(1)
// // elemSize := value.Stride(2)
// // value = value.AsStrided(ctx,
// // []int{numKVHeads, vHeadDim, cachedSize},
// // []int{value.Stride(0), value.Stride(1)},
// // elemSize*c.curCellRange.min,
// // )
// // } else {
// // vHeadDim := c.vHeadDims[c.curLayer]
// // rowSize := value.Stride(2)
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
// // panic("XXX")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
// // panic("XXX")
// // }
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
// // // the 1 becomes trailing and messes up later operations
// // // This isn't the right solution, but works around it...
// // if c.curMask.Dim(1) == 1 {
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
// // }
// // fmt.Fprintln(os.Stderr, key.ToString())
// // fmt.Fprintln(os.Stderr, value.ToString())
// // panic("XXX")
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
// return key, value, c.curMask
// }
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
// kHeadDim := key.Dim(3)
// vHeadDim := value.Dim(3)
// numKVHeads := key.Dim(1)
// batchSize := key.Dim(2)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
// // panic("XXX")
// if c.curBatchSize != batchSize {
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
// }
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
// if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
// c.kHeadDims[c.curLayer] = kHeadDim
// c.vHeadDims[c.curLayer] = vHeadDim
// c.numKVHeads[c.curLayer] = numKVHeads
// }
// if _, ok := c.values[c.curLayer]; !ok {
// // if c.config.PermutedV {
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
// // } else {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
// // }
// }
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
// // panic("XXX")
// // curLoc := 0 // TODO c.curLoc is now a tensor
// // kSize := numKVHeads * kHeadDim
// // vSize := numKVHeads * vHeadDim
// // start := []int{int(curLoc), 0}
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
// // strides := []int{1, 1}
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
// // panic("input value")
// // fmt.Fprintln(os.Stderr, t.ToString())
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not adjusted
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
// // value = value.Transpose(ctx, 2, 0, 1, 3)
// // valueCache := c.values[c.curLayer]
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
// // } else {
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
// // }
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
// // panic("XXX")
// }
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
// seqRange := newRange()
// for i := range c.cells {
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
// if slices.Contains(c.cells[i].sequences, dstSeq) {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
// }
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// c.cellRanges[dstSeq] = seqRange
// }
// func (c *Causal) CanResume(seq int, pos int32) bool {
// if c.swaMemorySize == math.MaxInt32 {
// return true
// }
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// return false
// }
// // for sliding window, check that the window of the new sequence is contained in
// // the window of what we are storing
// var first int32 = math.MaxInt32
// var last int32 = -1
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// first = min(first, c.cells[i].pos)
// last = max(last, c.cells[i].pos)
// }
// }
// if last == -1 {
// return false
// }
// posWindowStart := max(0, pos-c.swaWindowSize)
// return posWindowStart >= first && pos <= last+1
// }
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
// if c.shiftFn == nil {
// return ErrNotSupported
// }
// seqRange := c.cellRanges[seq]
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
// size := min(seqRange.max-start+1, c.maxBatch)
// offsets := make([]int32, size)
// var batchFirst, batchLast int
// batchFirst = -1
// for i := range offsets {
// cell := c.cells[start+i]
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
// offsets[i] = offset
// if batchFirst < 0 {
// batchFirst = i
// }
// batchLast = i
// }
// }
// if batchFirst < 0 {
// continue
// }
// offsets = offsets[batchFirst : batchLast+1]
// slog.Info("XXX Causal.shift creating new temporary context")
// ctx := c.backend.NewContext()
// kShift := ctx.Input().FromInts(offsets, len(offsets))
// for i, key := range c.keys {
// if key == nil {
// continue
// }
// kHeadDim := key.Dim(2)
// numKVHeads := key.Dim(1)
// rowSize := key.Stride(0)
// key = key.AsStrided(ctx,
// []int{len(offsets), numKVHeads, kHeadDim},
// []int{key.Stride(0), key.Stride(1)},
// rowSize*(start+batchFirst),
// )
// roped, err := c.shiftFn(ctx, i, key, kShift)
// if err != nil {
// ctx.Close()
// return err
// }
// ctx.Forward(roped.Copy(ctx, key))
// }
// ctx.Compute()
// ctx.Close()
// }
// return nil
// }
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
// // should return an error, which will trigger the runner to evaluate the full history and
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
// // results in use after free, so we don't do it for now.
// var offset int32
// if endIndex != math.MaxInt32 {
// offset = beginIndex - endIndex
// }
// seqRange := newRange()
// for i := range c.cells {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// if c.cells[i].pos >= endIndex {
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// return errors.New("shifting cells shared by multiple sequences not supported")
// }
// c.cells[i].pos += offset
// }
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// }
// if seqRange == newRange() {
// delete(c.cellRanges, seq)
// return nil
// }
// c.cellRanges[seq] = seqRange
// if endIndex != math.MaxInt32 {
// err := c.shift(seq, endIndex+offset, offset)
// if err != nil {
// return err
// }
// }
// return nil
// }
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

View File

@@ -1,973 +0,0 @@
package kvcache
// import (
// "fmt"
// "math"
// "slices"
// "testing"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type testCase struct {
// name string
// in []float32
// inShape []int
// seqs []int
// pos []int32
// expected []float32
// expectedShape []int
// expectedMask []float32
// }
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// t.Helper()
// for _, permuted := range []bool{false, true} {
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// fn(t, &testBackend{permutedV: permuted})
// })
// }
// }
// func TestStore(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// inShape: []int{2, 3, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// expectedShape: []int{2, 3, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{115, 215, 125, 225, 135, 235},
// inShape: []int{2, 3, 1},
// seqs: []int{0},
// pos: []int32{4},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// expectedShape: []int{2, 3, 5},
// expectedMask: []float32{0, 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWA(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 6, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWASeparateBatches(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "First seq 0",
// in: []float32{1, 2},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{0, 1},
// expected: []float32{1, 2},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 0",
// in: []float32{3, 4},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{2, 3},
// expected: []float32{2, 3, 4},
// expectedShape: []int{1, 1, 3},
// expectedMask: []float32{
// 0, 0, x,
// x, 0, 0,
// },
// },
// {
// name: "First seq 1",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{0, 1},
// expected: []float32{5, 6},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 1",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{2, 3},
// expected: []float32{6, 3, 4, 7, 8},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// x, x, x, 0, 0,
// },
// },
// {
// name: "Third seq 0",
// in: []float32{9, 10},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{9, 10, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWAMemCache(1, 3, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 2, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestChunkedAttention(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewChunkedAttentionCache(2, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// testCache(
// t, backend, cache,
// []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6, 7},
// inShape: []int{1, 1, 3},
// seqs: []int{0, 0, 0},
// pos: []int32{4, 5, 6},
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
// expectedShape: []int{1, 1, 7},
// expectedMask: []float32{
// x, x, x, x, 0, x, x,
// x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, 0,
// },
// },
// {
// name: "ThirdBatch",
// in: []float32{8, 9},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{7, 8},
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// expectedShape: []int{1, 1, 9},
// expectedMask: []float32{
// x, x, x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, x, x, 0,
// },
// },
// },
// )
// })
// }
// func TestSequences(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{2, 2},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestRemove(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// return key.Add(ctx, shift), nil
// })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err := cache.Remove(0, 1, math.MaxInt32)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveEnd",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{1, 2},
// expected: []float32{1, 5, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, 0, x, x, x,
// x, x, 0, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err = cache.Remove(0, 0, 1)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveMiddle",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{1, 2},
// expected: []float32{7, 4, 3, 4, 6, 8},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{
// 0, 0, x, x, x, x,
// 0, 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestCopy(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// cache.CopyPrefix(0, 1, 2)
// tests = []testCase{
// {
// name: "Copy",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{3, 4},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// if err != nil {
// panic(err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats(test.in, test.inShape...)
// cache.Put(context, tensor, tensor)
// out, _, mask := cache.Get(context)
// context.Forward(out, mask).Compute(out, mask)
// if !slices.Equal(out.Floats(), test.expected) {
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// }
// if !slices.Equal(out.Shape(), test.expectedShape) {
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// }
// if !slices.Equal(mask.Floats(), test.expectedMask) {
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// }
// })
// }
// }
// func TestCanResume(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// cache := NewSWACache(windowSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4},
// Sequences: []int{0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// cache.Put(context, tensor, tensor)
// // with window size 4, nothing has slid out of the window yet
// if !cache.CanResume(0, 0) {
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
// }
// if !cache.CanResume(0, 1) {
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
// }
// if !cache.CanResume(0, 2) {
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
// }
// if !cache.CanResume(0, 3) {
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// }
// if !cache.CanResume(0, 4) {
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// }
// // shift window by adding position 5
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{5},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// }
// })
// }
// func TestCanResumeSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// memSize := int32(5)
// cache := NewSWAMemCache(windowSize, memSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// cache.Put(context, tensor, tensor)
// // shift window by adding position 7
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{7},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 6) {
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// }
// if !cache.CanResume(0, 7) {
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// }
// })
// }
// type testBackend struct {
// ml.Backend
// permutedV bool
// }
// func (b *testBackend) NewContext() ml.Context {
// return &testContext{}
// }
// func (b *testBackend) NewContextSize(int) ml.Context {
// return &testContext{}
// }
// func (b *testBackend) CacheConfig() ml.CacheConfig {
// return ml.CacheConfig{PermutedV: b.permutedV}
// }
// type testContext struct {
// ml.Context
// }
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// total := 0
// if len(shape) > 0 {
// total = 1
// for _, s := range shape {
// total *= s
// }
// }
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// return c.Empty(dtype, shape...)
// }
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
// copy(t.data, s)
// return t
// }
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// f := make([]float32, len(s))
// for i := range f {
// f[i] = float32(s[i])
// }
// out := c.FromFloats(f, shape...)
// out.(*testTensor).dtype = ml.DTypeI32
// return out
// }
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// s := make([]float32, 0, int((stop-start)/step))
// for i := start; i < stop; i += step {
// s = append(s, i)
// }
// out := c.FromFloats(s, len(s))
// out.(*testTensor).dtype = dtype
// return out
// }
// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
// func (c *testContext) Compute(...ml.Tensor) {}
// func (c *testContext) Reserve() {}
// func (c *testContext) MaxGraphNodes() int {
// return 10
// }
// func (c *testContext) Close() {}
// type testTensor struct {
// ml.Tensor
// dtype ml.DType
// elementSize int
// data []float32
// shape []int
// }
// func (t *testTensor) Dim(n int) int {
// return t.shape[n]
// }
// func (t *testTensor) Stride(n int) int {
// stride := t.elementSize
// for i := range n {
// stride *= t.shape[i]
// }
// return stride
// }
// func (t *testTensor) Shape() []int {
// return t.shape
// }
// func (t *testTensor) DType() ml.DType {
// return t.dtype
// }
// func (t *testTensor) Floats() []float32 {
// out := make([]float32, len(t.data))
// copy(out, t.data)
// return out
// }
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = -t.data[i]
// }
// return out
// }
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// }
// return out
// }
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: t.data,
// shape: shape,
// }
// }
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// offset /= t.elementSize
// var s []int
// switch len(shape) {
// case 1:
// s = []int{shape[0]}
// case 3:
// s = []int{shape[0], shape[2]}
// case 5:
// s = []int{shape[0], shape[2], shape[4]}
// default:
// panic("unsupported number of dimensions")
// }
// context := &testContext{}
// view := context.Empty(t.dtype, s...).(*testTensor)
// view.data = t.data[offset : offset+len(view.data)]
// return view
// }
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// if len(t.shape) > 4 || len(order) > 4 {
// panic("permute only supports up to 4 dimensions")
// }
// if len(order) != len(t.shape) && len(order) != 4 {
// panic("invalid number of dimensions for permute")
// }
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
// orderFull := append(make([]int, 0, 4), order...)
// for len(orderFull) < 4 {
// orderFull = append(orderFull, len(orderFull))
// }
// seen := [4]bool{}
// shape4 := [4]int{1, 1, 1, 1}
// for i := 0; i < len(t.shape) && i < 4; i++ {
// shape4[i] = t.shape[i]
// }
// newShape4 := [4]int{1, 1, 1, 1}
// for axis := range 4 {
// dst := orderFull[axis]
// if dst < 0 || dst >= 4 {
// panic("invalid axis for permute")
// }
// if seen[dst] {
// panic("duplicate axis for permute")
// }
// seen[dst] = true
// newShape4[dst] = shape4[axis]
// }
// total := len(t.data)
// newData := make([]float32, total)
// if total > 0 {
// oldDims := shape4
// newDims := newShape4
// oldStride := [4]int{1, 1, 1, 1}
// newStride := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
// newStride[i] = newStride[i-1] * newDims[i-1]
// }
// var coords [4]int
// var newCoords [4]int
// for idx := range total {
// remainder := idx
// for axis := range 4 {
// dim := oldDims[axis]
// if dim == 0 {
// coords[axis] = 0
// continue
// }
// coords[axis] = remainder % dim
// remainder /= dim
// }
// for axis := range 4 {
// newCoords[orderFull[axis]] = coords[axis]
// }
// newIndex := 0
// for axis := range 4 {
// if newDims[axis] == 0 {
// continue
// }
// newIndex += newCoords[axis] * newStride[axis]
// }
// newData[newIndex] = t.data[idx]
// }
// }
// numDims := 4
// for numDims > 1 && newShape4[numDims-1] <= 1 {
// numDims--
// }
// newShape := make([]int, numDims)
// copy(newShape, newShape4[:numDims])
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: newData,
// shape: newShape,
// }
// }
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// dst := t
// srcTensor := src.(*testTensor)
// idxTensor := idxs.(*testTensor)
// shapeTo4D := func(shape []int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 0; i < len(shape) && i < 4; i++ {
// out[i] = shape[i]
// }
// return out
// }
// computeStrides := func(shape [4]int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// out[i] = out[i-1] * shape[i-1]
// }
// return out
// }
// dstShape4D := shapeTo4D(dst.shape)
// srcShape4D := shapeTo4D(srcTensor.shape)
// idxShape4D := shapeTo4D(idxTensor.shape)
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// panic("SetRows requires matching tensor shapes")
// }
// if srcShape4D[1] != idxShape4D[0] {
// panic("SetRows rows/index mismatch")
// }
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// panic("SetRows cannot broadcast indices")
// }
// if idxShape4D[3] != 1 {
// panic("SetRows expects 1D or 2D index tensors")
// }
// dstStride := computeStrides(dstShape4D)
// srcStride := computeStrides(srcShape4D)
// idxStride := computeStrides(idxShape4D)
// numColumns := srcShape4D[0]
// numRows := srcShape4D[1]
// for dim3Index := range dstShape4D[3] {
// for dim2Index := range dstShape4D[2] {
// idxDim2 := 0
// idxDim3 := 0
// if idxShape4D[1] > 0 {
// idxDim2 = dim2Index % idxShape4D[1]
// }
// if idxShape4D[2] > 0 {
// idxDim3 = dim3Index % idxShape4D[2]
// }
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
// for row := range numRows {
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// if idx < 0 || idx >= dstShape4D[1] {
// panic("SetRows index out of range")
// }
// srcOffset := srcBase + row*srcStride[1]
// dstOffset := dstBase + idx*dstStride[1]
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// }
// }
// }
// return dst
// }
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// copy(t2.(*testTensor).data, t.data)
// return nil
// }

View File

@@ -1,144 +0,0 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type MLXCausal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewMLXCausalCache() *MLXCausal {
return &MLXCausal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
func (c *MLXCausal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *MLXCausal) Close() {
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX MLXCausal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

134
x/mlxrunner/imagegen.go Normal file
View File

@@ -0,0 +1,134 @@
//go:build mlx
package mlxrunner
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models.
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
var imageGenMu sync.Mutex
// loadImageModel loads an image generation model.
func (s *server) loadImageModel() error {
// Check memory requirements before loading
var requiredMemory uint64
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(manifest.TotalTensorSize())
}
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(s.modelName)
slog.Info("detected image model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(s.modelName); err != nil {
return fmt.Errorf("failed to load flux2 model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(s.modelName); err != nil {
return fmt.Errorf("failed to load zimage model: %w", err)
}
model = m
}
s.imageModel = model
return nil
}
// handleImageCompletion handles image generation requests.
func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
// Serialize generation requests - MLX model may not handle concurrent generation
imageGenMu.Lock()
defer imageGenMu.Unlock()
// Set seed if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
ctx := r.Context()
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
// Generate image
img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
resp := Response{
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}

420
x/mlxrunner/llm.go Normal file
View File

@@ -0,0 +1,420 @@
//go:build mlx
package mlxrunner
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextModel is the interface for LLM text generation models.
type TextModel interface {
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
NewCache(maxSeqLen int32) []cache.Cache
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
MaxContextLength() int32
NumLayers() int
}
// llmState holds the state for LLM generation
type llmState struct {
model TextModel
}
var llmMu sync.Mutex
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
// Lazy initialization of generationStream
if generationStream == nil {
generationStream = mlx.NewStream()
}
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
// Decoder wraps model + cache for autoregressive generation.
// This matches the pattern from cmd/engine/generate.go
type Decoder struct {
model TextModel
caches []cache.Cache
vocabSize int32
temp float32
token *mlx.Array // Current token (kept across iterations)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
}
func NewDecoder(m TextModel, temp float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// Process all-but-1 tokens in chunks, eval cache state for memory management
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
logits := d.model.Forward(x, d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
// sample samples from logits using temperature scaling
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
lastLogits = mlx.Reshape(lastLogits, vocabSize)
if temp <= 0 || temp < 0.01 {
// Greedy decoding
return mlx.Argmax(lastLogits, -1, false)
}
// Apply temperature scaling
scaled := mlx.DivScalar(lastLogits, temp)
return mlx.RandomCategorical(scaled, -1, 1)
}
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
manifest, err := imagegen.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
var modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(configData, &modelConfig); err != nil {
return fmt.Errorf("failed to parse config.json: %w", err)
}
arch := ""
if len(modelConfig.Architectures) > 0 {
arch = modelConfig.Architectures[0]
}
if arch == "" {
arch = modelConfig.ModelType
}
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
// Load the appropriate model based on architecture
var model TextModel
archLower := strings.ToLower(arch)
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(manifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}
model = m
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
default:
return fmt.Errorf("LLM architecture %q is not yet supported. "+
"Supported architectures: glm4-moe-lite. "+
"Please convert your model to GGUF format or use a supported architecture", arch)
}
s.llmModel = &llmState{
model: model,
}
return nil
}
// handleLLMCompletion handles LLM text generation requests.
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
// Serialize generation requests
llmMu.Lock()
defer llmMu.Unlock()
if err := s.llmGenerate(w, r, req); err != nil {
slog.Error("LLM generation failed", "error", err)
// Don't send error if we've already started streaming
}
}
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
state := s.llmModel
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("streaming not supported")
}
tok := state.model.Tokenizer()
// The prompt is already formatted by the server using the model's renderer
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
prompt := req.Prompt
// Tokenize the prompt
inputIDs := tok.Encode(prompt, true)
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
// Generation parameters
maxTokens := int(state.model.MaxContextLength())
if maxTokens <= 0 {
maxTokens = 4096
}
if req.Options != nil && req.Options.NumPredict > 0 {
maxTokens = req.Options.NumPredict
}
temperature := float32(0.7)
if req.Options != nil && req.Options.Temperature > 0 {
temperature = float32(req.Options.Temperature)
}
// Enable MLX compilation for better performance
mlx.EnableCompile()
// Create decoder with fresh caches
dec := NewDecoder(state.model, temperature)
prefillStart := time.Now()
prefillTokens := dec.prefill(inputIDs)
// Prefill measurement includes time to first token
firstToken := dec.step()
prefillDuration := time.Since(prefillStart)
promptEvalDuration := prefillDuration
enc := json.NewEncoder(w)
ctx := r.Context()
generated := 0
stopReason := "max_tokens"
// Handle first token
generated++
if tok.IsEOS(firstToken) {
resp := Response{
Done: true,
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}
text := tok.Decode([]int32{firstToken})
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
genStart := time.Now()
// Generation loop
for n := 1; n < maxTokens; n++ {
// Check for cancellation
select {
case <-ctx.Done():
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
break
default:
}
if stopReason != "max_tokens" {
break
}
token := dec.step()
generated++
if tok.IsEOS(token) {
stopReason = fmt.Sprintf("eos_token:%d", token)
break
}
text := tok.Decode([]int32{token})
// Check for stop sequences
if req.Options != nil && len(req.Options.Stop) > 0 {
shouldStop := false
var matchedStop string
for _, stop := range req.Options.Stop {
if strings.Contains(text, stop) {
text = strings.Split(text, stop)[0]
shouldStop = true
matchedStop = stop
break
}
}
if shouldStop {
if text != "" {
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
}
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
break
}
}
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
// Periodically clear MLX cache
if n%256 == 0 {
mlx.ClearCache()
}
}
// Clean up
mlx.ClearCache()
// Send final response with stats
evalDuration := time.Since(genStart)
resp = Response{
Done: true,
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
EvalCount: generated,
EvalDuration: int(evalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}

204
x/mlxrunner/runner.go Normal file
View File

@@ -0,0 +1,204 @@
//go:build mlx
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
package mlxrunner
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
// Set up logging with appropriate level from environment
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
// Initialize MLX
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
// Detect model type from capabilities
mode := detectModelMode(*modelName)
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
// Create and start server
server, err := newServer(*modelName, *port, mode)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
// LLM-specific endpoints
if mode == ModeLLM {
mux.HandleFunc("/tokenize", server.tokenizeHandler)
mux.HandleFunc("/embedding", server.embeddingHandler)
}
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("mlx runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := imagegen.DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
return ModeImageGen
}
}
// Default to LLM mode for safetensors models without known image gen types
return ModeLLM
}
// server holds the model and handles HTTP requests.
type server struct {
mode ModelMode
modelName string
port int
// Image generation model (when mode == ModeImageGen)
imageModel ImageModel
// LLM model (when mode == ModeLLM)
llmModel *llmState
}
// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
s := &server{
mode: mode,
modelName: modelName,
port: port,
}
switch mode {
case ModeImageGen:
if err := s.loadImageModel(); err != nil {
return nil, fmt.Errorf("failed to load image model: %w", err)
}
case ModeLLM:
if err := s.loadLLMModel(); err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}
}
return s, nil
}
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := HealthResponse{Status: "ok"}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch s.mode {
case ModeImageGen:
s.handleImageCompletion(w, r, req)
case ModeLLM:
s.handleLLMCompletion(w, r, req)
}
}
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
var req struct {
Content string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
tok := s.llmModel.model.Tokenizer()
tokens := tok.Encode(req.Content, false)
// Convert int32 to int for JSON response
intTokens := make([]int, len(tokens))
for i, t := range tokens {
intTokens[i] = int(t)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}

View File

@@ -1,10 +1,10 @@
//go:build !mlx
package runner
package mlxrunner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
return errors.New("image generation not available: build with mlx tag")
return errors.New("MLX runner not available: build with mlx tag")
}

View File

@@ -1,4 +1,4 @@
package imagegen
package mlxrunner
import (
"bufio"
@@ -23,19 +23,19 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
mode ModelMode
vramSize uint64
done chan error
client *http.Client
@@ -43,10 +43,10 @@ type Server struct {
lastErrLock sync.Mutex
}
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
@@ -71,8 +71,8 @@ func NewServer(modelName string) (*Server, error) {
exe = eval
}
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -105,17 +105,21 @@ func NewServer(modelName string) (*Server, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Get total weight size from manifest
var weightSize uint64
if manifest, err := LoadManifest(modelName); err == nil {
weightSize = uint64(manifest.TotalTensorSize())
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
vramSize = uint64(manifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
vramSize: weightSize,
mode: mode,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
@@ -126,23 +130,23 @@ func NewServer(modelName string) (*Server, error) {
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("image-runner", "msg", scanner.Text())
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
@@ -165,6 +169,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -200,18 +205,18 @@ func (s *Server) waitUntilRunning() error {
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("image runner is ready", "port", s.port)
slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
@@ -225,8 +230,12 @@ func (s *Server) getLastErr() string {
return s.lastErr
}
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
@@ -240,22 +249,26 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"`
}{
creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Steps: int(req.Steps),
Seed: seed,
Images: images,
}
// Pass LLM options if present
if req.Options != nil {
creq.Options = &RequestOptions{
NumPredict: req.Options.NumPredict,
Temperature: float64(req.Options.Temperature),
TopP: float64(req.Options.TopP),
TopK: req.Options.TopK,
Stop: req.Options.Stop,
}
}
body, err := json.Marshal(creq)
if err != nil {
return err
@@ -282,25 +295,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
// Parse subprocess response (has singular "image" field)
// Parse subprocess response
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
// Log stop reason when generation completes
if raw.Done && raw.StopReason != "" {
slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
}
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
@@ -309,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
}
return scanner.Err()
// Scanner exited without receiving Done - connection was likely closed
scanErr := scanner.Err()
if scanErr != nil {
slog.Error("mlx scanner error", "error", scanErr)
} else {
slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
}
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
}
return scanErr
}
// Close terminates the subprocess.
@@ -318,7 +359,7 @@ func (s *Server) Close() error {
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
@@ -347,23 +388,56 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Context length is not applicable for image generation.
// ContextLength returns the context length (not applicable for image generation).
func (s *Server) ContextLength() int {
return 0
}
// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("not supported")
return nil, 0, errors.New("embeddings not supported for MLX models")
}
// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("not supported")
body, err := json.Marshal(map[string]string{"content": content})
if err != nil {
return nil, err
}
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
}
var result struct {
Tokens []int `json:"tokens"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
return result.Tokens, nil
}
// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("not supported")
return "", errors.New("detokenization not supported for MLX models")
}
// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -373,9 +447,17 @@ func (s *Server) Pid() int {
return -1
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:

81
x/mlxrunner/types.go Normal file
View File

@@ -0,0 +1,81 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner
// Request is the request format for completion requests.
type Request struct {
Prompt string `json:"prompt"`
// LLM-specific fields
Options *RequestOptions `json:"options,omitempty"`
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}
// RequestOptions contains LLM-specific generation options.
type RequestOptions struct {
NumPredict int `json:"num_predict,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop []string `json:"stop,omitempty"`
}
// Response is streamed back for each progress update.
type Response struct {
// Text generation response
Content string `json:"content,omitempty"`
// Image generation response
Image string `json:"image,omitempty"` // Base64-encoded PNG
// Common fields
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped
// Progress fields
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
// Statistics
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
// HealthResponse is returned by the health endpoint.
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress,omitempty"`
}
// ModelMode represents the type of model being run.
type ModelMode int
const (
// ModeLLM indicates a text generation model.
ModeLLM ModelMode = iota
// ModeImageGen indicates an image generation model.
ModeImageGen
)
func (m ModelMode) String() string {
switch m {
case ModeLLM:
return "llm"
case ModeImageGen:
return "imagegen"
default:
return "unknown"
}
}

View File

@@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) {
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewMLXCausalCache()
m.Cache = kvcache.NewCausalCache()
return &m, nil
}

View File

@@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
// First pass: collect all tensor info and identify scale tensors
type tensorData struct {
info *safetensorsTensorInfo
digest string
}
tensorMap := make(map[string]*tensorData)
scaleMap := make(map[string]*tensorData) // base name -> scale tensor info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
@@ -178,28 +187,96 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
continue
}
// Convert shape from int to uint64
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
td := &tensorData{info: info, digest: layer.Digest}
if strings.HasSuffix(layer.Name, "_scale") {
baseName := strings.TrimSuffix(layer.Name, "_scale")
scaleMap[baseName] = td
} else if strings.HasSuffix(layer.Name, "_qbias") {
// Skip qbias tensors - they're included with the quantized weight
continue
} else {
tensorMap[layer.Name] = td
}
}
// Second pass: build tensor list with quantization info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: info.Dtype,
Shape: shape,
})
// Skip scale and qbias tensors
if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
continue
}
td := tensorMap[layer.Name]
if td == nil {
continue
}
// Check if this tensor has a corresponding scale tensor (quantized)
scaleTd := scaleMap[layer.Name]
if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
// Quantized tensor - detect bits from shapes
weightCols := td.info.Shape[len(td.info.Shape)-1]
scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]
// Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
// Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
// Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
var bits int
var quantType string
if weightCols*8/scaleCols == 32 {
bits = 4
quantType = "Q4"
} else if weightCols*4/scaleCols == 64 {
bits = 8
quantType = "Q8"
} else {
// Unknown quantization, show raw
quantType = td.info.Dtype
}
// Calculate unpacked shape
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
if bits > 0 {
packFactor := int64(32 / bits)
shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: quantType,
Shape: shape,
})
} else {
// Non-quantized tensor
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: td.info.Dtype,
Shape: shape,
})
}
}
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
@@ -207,16 +284,38 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// Check if model is quantized by looking for _scale tensors
// First try to read quantization from model_index.json
var modelIndex struct {
Quantization string `json:"quantization"`
}
if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
return modelIndex.Quantization, nil
}
// Fallback: detect from tensor names
hasScales := false
hasQBias := false
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
// Model is quantized - return FP8 (affine quantization)
return "FP8", nil
hasScales = true
}
if strings.HasSuffix(layer.Name, "_qbias") {
hasQBias = true
}
}
}
if hasScales {
if hasQBias {
// Affine mode (has scale + qbias) - could be Q4 or Q8
// Default to Q4 as it's more common
return "Q4", nil
}
// No qbias = NVFP4
return "NVFP4", nil
}
// Not quantized - return torch_dtype from config.json
var cfg struct {
TorchDtype string `json:"torch_dtype"`