Compare commits

...

8 Commits

Author SHA1 Message Date
Eva H
3323c1d319 app: add upgrade configuration to settings page (#13512) 2026-02-23 18:08:52 -05:00
Jesse Gross
f20dc6b698 mlx: don't default to affine quantization for unquantized models
Otherwise the BF16 versions of models trigger segfaults when they
call into quantized kernels.
2026-02-23 15:03:53 -08:00
Jeffrey Morgan
4b2ac1f369 model: improvements to LFM architectures (#14368) 2026-02-23 14:38:10 -08:00
Jesse Gross
8daf47fb3a mlxrunner: Fix duplicate log prefixes and reduce log noise
Pass subprocess stdout/stderr through to the parent's stderr directly
instead of re-wrapping each line with slog. The subprocess already
writes structured slog output, so the re-wrapping produced nested
timestamps, levels, and message fields that were hard to read.

Also downgrade verbose KV cache debug logs to trace level.
2026-02-23 14:09:20 -08:00
Eva H
6c980579cd ui: use capability-based detection for web search (#14336) 2026-02-23 15:00:09 -05:00
Jesse Gross
5c73c4e2ee mlxrunner: Simplify KV cache to single-entry prefix matching
The KV cache previously used a tree structure which could
store multiple divergent sequences, which is good for cache
reuse. However, this is typically used in conjunction with
paged attention so each node in the tree can store just a
chunk of the KV cache and they can be stitched together later.
We don't currently do this, so the cache was storing copies of
the full cache for each past sequence.

This redundancy, plus the lack of resource limits, caused significant
memory use as a conversation grew. Instead, this changes to store
a single entry for the cache, which can be prefix matched. Although
it is less ideal for multiple users, it largely matches Ollama's
current behavior. It can be improved as additional pieces are fleshed
out.
2026-02-23 09:50:07 -08:00
Jesse Gross
5daf59cc66 mlxrunner: Fix memory leaks with pin/sweep lifecycle management
The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary as MLX tracks references internally. It is also error
prone as it is easy to create new arrays and forget to free them
when the Go variable goes out of scope.

Instead, we can pin just the arrays we want (typically outputs and
specific intermediates, like the cache). All other arrays are freed
by default when we run sweep. This avoids most causes of memory leaks
while still giving the freedom to save what we want.
2026-02-23 09:50:07 -08:00
Jeffrey Morgan
0ade9205cc models: add nemotronh architecture support (#14356) 2026-02-22 15:09:14 -08:00
74 changed files with 7540 additions and 1727 deletions

View File

@@ -253,6 +253,8 @@ func main() {
done <- osrv.Run(octx)
}()
upd := &updater.Updater{Store: st}
uiServer := ui.Server{
Token: token,
Restart: func() {
@@ -267,6 +269,10 @@ func main() {
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
}
srv := &http.Server{
@@ -284,8 +290,13 @@ func main() {
slog.Debug("background desktop server done")
}()
updater := &updater.Updater{Store: st}
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil {
@@ -348,6 +359,18 @@ func startHiddenTasks() {
// CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
)
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 14
const currentSchemaVersion = 15
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -86,6 +86,7 @@ func (db *database) init() error {
think_level TEXT NOT NULL DEFAULT '',
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -257,6 +258,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v13 to v14: %w", err)
}
version = 14
case 14:
// add auto_update_enabled column to settings table
if err := db.migrateV14ToV15(); err != nil {
return fmt.Errorf("migrate v14 to v15: %w", err)
}
version = 15
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -496,6 +503,21 @@ func (db *database) migrateV13ToV14() error {
return nil
}
// migrateV14ToV15 adds the auto_update_enabled column to the settings table
func (db *database) migrateV14ToV15() error {
// Adding the column is idempotent: a "duplicate column name" error means a
// previous run already added it, so that specific failure is tolerated.
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add auto_update_enabled column: %w", err)
}
// Persist the new schema version so this migration is not re-run.
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -526,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
}
// duplicateColumnError reports whether err is SQLite's "duplicate column name"
// error, used to make ALTER TABLE ... ADD COLUMN migrations idempotent.
//
// NOTE(review): the body below interleaves the old sqlite3-typed check with its
// string-matching replacement (diff-rendering artifact — the +/- markers were
// stripped); only the final string-based return should remain in the new code.
func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name")
}
return false
return err != nil && strings.Contains(err.Error(), "duplicate column name")
}
// columnNotExists reports whether err is SQLite's "no such column" error,
// allowing callers to detect queries against columns that predate a migration.
//
// NOTE(review): as with duplicateColumnError, the old typed check and its
// string-matching replacement appear interleaved here (diff-rendering
// artifact); only the final string-based return belongs in the new code.
func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column")
}
return false
return err != nil && strings.Contains(err.Error(), "no such column")
}
func (db *database) getAllChats() ([]Chat, error) {
@@ -1152,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1164,9 +1178,9 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}

View File

@@ -166,6 +166,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
}
type Store struct {

View File

@@ -414,6 +414,7 @@ export class Settings {
ThinkLevel: string;
SelectedModel: string;
SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -431,6 +432,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
}
}
export class SettingsResponse {

View File

@@ -17,7 +17,10 @@ import {
} from "@/hooks/useChats";
import { useNavigate } from "@tanstack/react-router";
import { useSelectedModel } from "@/hooks/useSelectedModel";
import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
import {
useHasVisionCapability,
useHasToolsCapability,
} from "@/hooks/useModelCapabilities";
import { useUser } from "@/hooks/useUser";
import { DisplayLogin } from "@/components/DisplayLogin";
import { ErrorEvent, Message } from "@/gotypes";
@@ -149,12 +152,7 @@ function ChatForm({
} = useSettings();
const { cloudDisabled } = useCloudStatus();
// current supported models for web search
const modelLower = selectedModel?.model.toLowerCase() || "";
const supportsWebSearch =
modelLower.startsWith("gpt-oss") ||
modelLower.startsWith("qwen3") ||
modelLower.startsWith("deepseek-v3");
const supportsWebSearch = useHasToolsCapability(selectedModel?.model);
// Use per-chat thinking level instead of global
const thinkLevel: ThinkingLevel =
settingsThinkLevel === "none" || !settingsThinkLevel

View File

@@ -15,6 +15,7 @@ import {
XMarkIcon,
CogIcon,
ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
@@ -440,6 +441,29 @@ export default function Settings() {
</div>
</Field>
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div>
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled
? "Automatically download updates when available."
: "Updates will not be downloaded automatically."}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">

View File

@@ -20,3 +20,8 @@ export function useHasVisionCapability(modelName: string | undefined) {
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
}
// Returns true when the given model advertises the "tools" capability.
// Mirrors useHasVisionCapability above; resolves to false while the
// capabilities query is still loading or when modelName is undefined.
export function useHasToolsCapability(modelName: string | undefined) {
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
return capabilitiesResponse?.capabilities?.includes("tools") ?? false;
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
@@ -106,6 +107,10 @@ type Server struct {
// Dev is true if the server is running in development mode
Dev bool
// Updater for checking and downloading updates
Updater *updater.Updater
UpdateAvailableFunc func()
}
func (s *Server) log() *slog.Logger {
@@ -829,8 +834,9 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
if !hasAttachments {
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
hasToolsCapability := slices.Contains(details.Capabilities, model.CapabilityTools)
if WebSearchEnabled {
if WebSearchEnabled && hasToolsCapability {
if supportsBrowserTools(req.Model) {
browserState, ok := s.browserState(chat)
if !ok {
@@ -840,7 +846,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
registry.Register(tools.NewBrowserSearch(browser))
registry.Register(tools.NewBrowserOpen(browser))
registry.Register(tools.NewBrowserFind(browser))
} else if supportsWebSearchTools(req.Model) {
} else {
registry.Register(&tools.WebSearch{})
registry.Register(&tools.WebFetch{})
}
@@ -1446,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err)
}
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
@@ -1648,17 +1672,6 @@ func supportsBrowserTools(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
}
// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
func supportsWebSearchTools(model string) bool {
model = strings.ToLower(model)
prefixes := []string{"qwen3", "deepseek-v3"}
for _, p := range prefixes {
if strings.HasPrefix(model, p) {
return true
}
}
return false
}
// buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {

View File

@@ -4,6 +4,7 @@ package ui
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -11,9 +12,11 @@ import (
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/updater"
)
func TestHandlePostApiSettings(t *testing.T) {
@@ -522,3 +525,290 @@ func TestUserAgentTransport(t *testing.T) {
t.Logf("User-Agent transport successfully set: %s", receivedUA)
}
// TestSupportsBrowserTools verifies that browser-tool support is decided by a
// case-insensitive "gpt-oss" prefix match on the model name, and that other
// model families (and the empty string) are rejected.
func TestSupportsBrowserTools(t *testing.T) {
tests := []struct {
model string
want bool
}{
{"gpt-oss", true},
{"gpt-oss-latest", true},
{"GPT-OSS", true},
{"Gpt-Oss-v2", true},
{"qwen3", false},
{"deepseek-v3", false},
{"llama3.3", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
if got := supportsBrowserTools(tt.model); got != tt.want {
t.Errorf("supportsBrowserTools(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}
func TestWebSearchToolRegistration(t *testing.T) {
// Validates that the capability-gating logic in chat() correctly
// decides which tools to register based on model capabilities and
// the web search flag.
//
// NOTE(review): this test re-implements the decision logic rather than
// exercising the chat() handler itself, so it must be kept in sync with
// chat() manually if that gating ever changes.
tests := []struct {
name string
webSearchEnabled bool
hasToolsCap bool
model string
wantBrowser bool // expects browser tools (gpt-oss)
wantWebSearch bool // expects basic web search/fetch tools
wantNone bool // expects no tools registered
}{
{
name: "web search enabled with tools capability - browser model",
webSearchEnabled: true,
hasToolsCap: true,
model: "gpt-oss-latest",
wantBrowser: true,
},
{
name: "web search enabled with tools capability - non-browser model",
webSearchEnabled: true,
hasToolsCap: true,
model: "qwen3",
wantWebSearch: true,
},
{
name: "web search enabled without tools capability",
webSearchEnabled: true,
hasToolsCap: false,
model: "llama3.3",
wantNone: true,
},
{
name: "web search disabled with tools capability",
webSearchEnabled: false,
hasToolsCap: true,
model: "qwen3",
wantNone: true,
},
{
name: "web search disabled without tools capability",
webSearchEnabled: false,
hasToolsCap: false,
model: "llama3.3",
wantNone: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Replicate the decision logic from chat() handler
gotBrowser := false
gotWebSearch := false
if tt.webSearchEnabled && tt.hasToolsCap {
if supportsBrowserTools(tt.model) {
gotBrowser = true
} else {
gotWebSearch = true
}
}
// Check both directions: expected tools must be present, and
// unexpected tools must be absent.
if tt.wantBrowser && !gotBrowser {
t.Error("expected browser tools to be registered")
}
if tt.wantWebSearch && !gotWebSearch {
t.Error("expected web search tools to be registered")
}
if tt.wantNone && (gotBrowser || gotWebSearch) {
t.Error("expected no tools to be registered")
}
if !tt.wantBrowser && gotBrowser {
t.Error("unexpected browser tools registered")
}
if !tt.wantWebSearch && gotWebSearch {
t.Error("unexpected web search tools registered")
}
})
}
}
// TestSettingsToggleAutoUpdateOff_CancelsDownload drives the settings handler
// with AutoUpdateEnabled toggled off and verifies the handler succeeds and the
// setting is persisted.
//
// NOTE(review): despite the name, no download is in flight here, so actual
// cancellation is not observed — that path is covered by
// TestCancelOngoingDownload in the updater package.
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update enabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Separate store for the updater so it does not contend with testStore.
upd := &updater.Updater{Store: &store.Store{
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
}}
defer upd.Store.Close()
// We can't easily mock CancelOngoingDownload, but we can verify
// the full settings handler flow works without error
server := &Server{
Store: testStore,
Restart: func() {},
Updater: upd,
}
// Disable auto-update via settings API
settings.AutoUpdateEnabled = false
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
// Verify settings were saved with auto-update disabled
saved, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
if saved.AutoUpdateEnabled {
t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
}
}
// TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification verifies
// that re-enabling auto-update while an update is already downloaded invokes
// UpdateAvailableFunc (the tray notification hook).
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update disabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Simulate that an update was previously downloaded
// NOTE(review): UpdateDownloaded is package-level mutable state; the defer
// restores it, but tests touching it cannot safely run in parallel.
oldVal := updater.UpdateDownloaded
updater.UpdateDownloaded = true
defer func() { updater.UpdateDownloaded = oldVal }()
var notificationCalled atomic.Bool
server := &Server{
Store: testStore,
Restart: func() {},
UpdateAvailableFunc: func() {
notificationCalled.Store(true)
},
}
// Re-enable auto-update via settings API
settings.AutoUpdateEnabled = true
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
if !notificationCalled.Load() {
t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
}
}
// TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck verifies that
// re-enabling auto-update with no staged update does NOT fire the tray
// notification (the handler should instead trigger an immediate check).
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update disabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Ensure no pending update - clear both the downloaded flag and the stage dir
oldVal := updater.UpdateDownloaded
updater.UpdateDownloaded = false
defer func() { updater.UpdateDownloaded = oldVal }()
oldStageDir := updater.UpdateStageDir
updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
defer func() { updater.UpdateStageDir = oldStageDir }()
upd := &updater.Updater{Store: &store.Store{
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
}}
defer upd.Store.Close()
// Start the checker so its checkNow channel is initialized; otherwise
// TriggerImmediateCheck (called by the settings handler) would silently
// no-op on a nil channel. The deferred cancel stops the checker at test end.
ctx, cancel := context.WithCancel(t.Context())
upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
defer cancel()
var notificationCalled atomic.Bool
server := &Server{
Store: testStore,
Restart: func() {},
Updater: upd,
UpdateAvailableFunc: func() {
notificationCalled.Store(true)
},
}
// Re-enable auto-update via settings API
settings.AutoUpdateEnabled = true
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
// UpdateAvailableFunc should NOT be called since there's no pending update
if notificationCalled.Load() {
t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
}
}

View File

@@ -19,6 +19,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version)
currentVersion := version.Version
query.Add("version", currentVersion)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
}
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil {
return err
}
// In case of slow downloads, continue the update check in the background
bgctx, cancel := context.WithCancel(ctx)
defer cancel()
bgctx, bgcancel := context.WithCancel(downloadCtx)
defer bgcancel()
go func() {
for {
select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename)
if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil
}
@@ -244,33 +259,84 @@ func cleanupOldDownloads(stageDir string) {
}
type Updater struct {
Store *store.Store
Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload cancels any currently running download.
// Safe to call from any goroutine: cancelDownloadLock guards the cancel
// function, and clearing it afterwards makes repeated calls a no-op.
func (u *Updater) CancelOngoingDownload() {
u.cancelDownloadLock.Lock()
defer u.cancelDownloadLock.Unlock()
if u.cancelDownload != nil {
slog.Info("cancelling ongoing update download")
u.cancelDownload()
u.cancelDownload = nil
}
}
// TriggerImmediateCheck signals the background checker to check for updates
// immediately. The non-blocking send means at most one immediate check can be
// queued; extra triggers while one is pending are dropped.
//
// NOTE(review): checkNow is assigned in StartBackgroundUpdaterChecker without
// any synchronization, so calling this concurrently with checker startup is a
// data race — confirm callers only invoke it after the checker has started.
func (u *Updater) TriggerImmediateCheck() {
if u.checkNow != nil {
select {
case u.checkNow <- struct{}{}:
default:
// Check already pending, no need to queue another
}
}
}
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select {
case <-ctx.Done():
slog.Debug("stopping background update checker")
return
default:
time.Sleep(UpdateCheckInterval)
case <-u.checkNow:
// Immediate check triggered
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
}
}
}()

View File

@@ -11,6 +11,8 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
"sync/atomic"
"testing"
"time"
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
defer server.Close()
slog.Debug("server", "url", server.URL)
updater := &Updater{Store: &store.Store{}}
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close() // Ensure database is closed
UpdateCheckURLBase = server.URL + "/update.json"
updatePresent, resp := updater.checkForUpdate(t.Context())
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:
@@ -99,3 +111,264 @@ func TestBackgoundChecker(t *testing.T) {
}
}
}
// TestAutoUpdateDisabledSkipsDownload runs the background checker against a
// server that always advertises an update, with AutoUpdateEnabled=false, and
// verifies that the installer is never fetched and the callback never fires.
//
// Fixes vs. the previous version: the unused `done` channel (created and
// closed but never received from) is removed, and the callback uses t.Error
// instead of t.Fatal — t.Fatal must not be called from a goroutine other than
// the test's own (the checker invokes the callback from its goroutine).
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
	UpdateStageDir = t.TempDir()
	var downloadAttempted atomic.Bool
	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()

	// Short delays so several check cycles run within the test window.
	UpdateCheckInitialDelay = 5 * time.Millisecond
	UpdateCheckInterval = 5 * time.Millisecond
	VerifyDownload = func() error {
		return nil
	}

	// Test server: the manifest always offers 9.9.9; any fetch of the
	// installer itself flips downloadAttempted so we can detect it.
	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			w.Write([]byte(
				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
					server.URL+"/9.9.9/"+Installer)))
		} else if r.URL.Path == "/9.9.9/"+Installer {
			downloadAttempted.Store(true)
			buf := &bytes.Buffer{}
			zw := zip.NewWriter(buf)
			zw.Close()
			io.Copy(w, buf)
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"

	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()

	// Ensure auto-update is disabled.
	settings, err := updater.Store.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = false
	if err := updater.Store.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	cb := func(ver string) error {
		// Runs on the checker goroutine, so t.Fatal is not allowed here.
		t.Error("callback should not be called when auto-update is disabled")
		return nil
	}
	updater.StartBackgroundUpdaterChecker(ctx, cb)

	// Wait long enough for multiple check cycles.
	time.Sleep(50 * time.Millisecond)

	if downloadAttempted.Load() {
		t.Fatal("download should not be attempted when auto-update is disabled")
	}
}
// TestAutoUpdateReenabledDownloadsUpdate verifies that while auto-update is
// disabled no download occurs, and that flipping the stored setting back on
// is picked up by the running checker, which then downloads and invokes the
// callback.
func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
// Buffered so the checker-goroutine callback never blocks on send.
callbackCalled := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Short delays so the checker cycles quickly during the test.
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
// Test server: manifest always offers 9.9.9; fetching the installer
// records the attempt and returns a minimal valid zip.
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer upd.Store.Close()
// Start with auto-update disabled
settings, err := upd.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := upd.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
select {
case callbackCalled <- struct{}{}:
default:
}
return nil
}
upd.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for a few cycles with auto-update disabled - no download should happen
time.Sleep(50 * time.Millisecond)
if downloadAttempted.Load() {
t.Fatal("download should not happen while auto-update is disabled")
}
// Re-enable auto-update
settings.AutoUpdateEnabled = true
if err := upd.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Wait for the checker to pick it up and download
select {
case <-callbackCalled:
// Success: download happened and callback was called after re-enabling
if !downloadAttempted.Load() {
t.Fatal("expected download to be attempted after re-enabling")
}
case <-time.After(5 * time.Second):
t.Fatal("expected download and callback after re-enabling auto-update")
}
}
// TestCancelOngoingDownload verifies that CancelOngoingDownload aborts an
// in-flight installer download: the test server blocks the GET request until
// its request context is cancelled, which only happens when the client side
// actually tears the transfer down.
func TestCancelOngoingDownload(t *testing.T) {
	UpdateStageDir = t.TempDir()
	downloadStarted := make(chan struct{})
	downloadCancelled := make(chan struct{})
	ctx := t.Context()
	// Skip installer verification; cancellation is the behavior under test.
	VerifyDownload = func() error {
		return nil
	}
	// server is declared before the handler so the handler closure can embed
	// server.URL in the manifest it returns.
	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			w.Write([]byte(
				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
					server.URL+"/9.9.9/"+Installer)))
		} else if r.URL.Path == "/9.9.9/"+Installer {
			// Answer size probes immediately, without entering the blocking path.
			if r.Method == http.MethodHead {
				w.Header().Set("Content-Length", "1000000")
				w.WriteHeader(http.StatusOK)
				return
			}
			// Signal that download has started
			close(downloadStarted)
			// Wait for cancellation or timeout
			select {
			case <-r.Context().Done():
				close(downloadCancelled)
				return
			case <-time.After(5 * time.Second):
				t.Error("download was not cancelled in time")
			}
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"
	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()
	_, resp := updater.checkForUpdate(ctx)
	// Start download in goroutine
	go func() {
		_ = updater.DownloadNewRelease(ctx, resp)
	}()
	// Wait for download to start
	select {
	case <-downloadStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("download did not start in time")
	}
	// Cancel the download
	updater.CancelOngoingDownload()
	// Verify cancellation was received
	select {
	case <-downloadCancelled:
		// Success
	case <-time.After(2 * time.Second):
		t.Fatal("download cancellation was not received by server")
	}
}
// TestTriggerImmediateCheck verifies that TriggerImmediateCheck forces an
// out-of-band update check while the periodic interval (set to one hour) is
// nowhere near elapsing: the server-observed check count must increase.
func TestTriggerImmediateCheck(t *testing.T) {
	UpdateStageDir = t.TempDir()
	checkCount := atomic.Int32{}
	// Buffered so the handler never blocks when signaling completed checks.
	checkDone := make(chan struct{}, 10)
	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	// Set a very long interval so only TriggerImmediateCheck causes checks
	UpdateCheckInitialDelay = 1 * time.Millisecond
	UpdateCheckInterval = 1 * time.Hour
	VerifyDownload = func() error {
		return nil
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			checkCount.Add(1)
			select {
			case checkDone <- struct{}{}:
			default:
			}
			// Return no update available
			w.WriteHeader(http.StatusNoContent)
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"
	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()
	cb := func(ver string) error {
		return nil
	}
	updater.StartBackgroundUpdaterChecker(ctx, cb)
	// Wait for goroutine to start and pass initial delay
	time.Sleep(10 * time.Millisecond)
	// Capture the baseline count; the initial-delay check may already have run.
	initialCount := checkCount.Load()
	// Trigger immediate check
	updater.TriggerImmediateCheck()
	// Wait for the triggered check
	select {
	case <-checkDone:
	case <-time.After(2 * time.Second):
		t.Fatal("triggered check did not happen")
	}
	finalCount := checkCount.Load()
	if finalCount <= initialCount {
		t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
	}
}

View File

@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil
}
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
// const ERROR_SUCCESS syscall.Errno = 0
// t.muMenus.RLock()
// menu := uintptr(t.menus[parentId])
// t.muMenus.RUnlock()
// res, _, err := pRemoveMenu.Call(
// menu,
// uintptr(menuItemId),
// MF_BYCOMMAND,
// )
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
// return err
// }
// t.delFromVisibleItems(parentId, menuItemId)
// return nil
// }
func (t *winTray) showMenu() error {
p := point{}
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))

View File

@@ -51,7 +51,6 @@ const (
IMAGE_ICON = 1 // Loads an icon
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
MF_BYCOMMAND = 0x00000000
MFS_DISABLED = 0x00000003
MFT_SEPARATOR = 0x00000800
MFT_STRING = 0x00000000

View File

@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
if err != nil {
return nil, nil, err
}
bts = sanitizeNonFiniteJSON(bts)
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json: %w", err)
}
if len(p.Architectures) < 1 {
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
conv = &lfm2Model{}
case "Lfm2VlForConditionalGeneration":
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
}
if t, ok := conv.(moreParser); ok {

View File

@@ -1,6 +1,8 @@
package convert
import (
"cmp"
"fmt"
"slices"
"strings"
@@ -13,42 +15,149 @@ type lfm2Model struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
BlockFFDim uint32 `json:"block_ff_dim"`
BlockMultipleOf uint32 `json:"block_multiple_of"`
BlockAutoAdjustFFDim bool `json:"block_auto_adjust_ff_dim"`
BlockFFNDimMultiplier float32 `json:"block_ffn_dim_multiplier"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
NumExperts uint32 `json:"num_experts"`
NumLocalExperts uint32 `json:"num_local_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumDenseLayers uint32 `json:"num_dense_layers"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
}
var _ ModelConverter = (*lfm2Model)(nil)
const (
defaultMaxPositionEmbeddings = uint32(128_000)
fallbackContextLength = uint32(32_768)
)
func (p *lfm2Model) isMoE() bool {
return p.ModelType == "lfm2_moe" || p.expertCount() > 0
}
func (p *lfm2Model) ropeFreqBase() float32 {
if p.RopeTheta != 0 {
return p.RopeTheta
}
return p.RopeParameters.RopeTheta
}
func (p *lfm2Model) expertCount() uint32 {
if p.NumLocalExperts > 0 {
return p.NumLocalExperts
}
return p.NumExperts
}
func (p *lfm2Model) feedForwardLength() uint32 {
ff := p.IntermediateSize
if p.BlockFFDim != 0 {
ff = p.BlockFFDim
}
if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
return ff
}
ff = (2 * ff) / 3
// Keep default multiplier behavior consistent with llama.cpp conversion.
if p.BlockFFNDimMultiplier != 0 {
ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
}
m := p.BlockMultipleOf
return m * ((ff + m - 1) / m)
}
func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
return p.isMoE() &&
p.VocabSize == 65536 &&
p.HiddenSize == 2048 &&
p.NumHiddenLayers == 40 &&
p.IntermediateSize == 11776 &&
p.NumAttentionHeads == 32 &&
p.NumKeyValueHeads == 8 &&
p.NumDenseLayers == 2 &&
p.expertCount() == 64 &&
p.NumExpertsPerToken == 4 &&
p.MoEIntermediateSize == 1536
}
func (p *lfm2Model) contextLength() uint32 {
if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
return fallbackContextLength
}
return p.MaxPositionEmbeddings
}
func (p *lfm2Model) KV(t *Tokenizer) KV {
architecture := "lfm2"
if p.isMoE() {
architecture = "lfm2moe"
}
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
kv["general.architecture"] = architecture
kv["tokenizer.ggml.pre"] = "lfm2"
kv["vocab_size"] = p.VocabSize
kv["block_count"] = p.NumHiddenLayers
kv["embedding_length"] = p.HiddenSize
kv["feed_forward_length"] = p.feedForwardLength()
kv["context_length"] = p.contextLength()
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
// (0 = shortconv layer, non-zero = attention layer with that many KV heads).
//
// Dense LFM2 in HF defaults to all attention layers when layer_types is absent.
// Preserve that behavior to avoid accidentally emitting all-conv metadata.
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
if len(p.LayerTypes) == 0 {
for i := range p.NumHiddenLayers {
kvHeadCounts[i] = p.NumKeyValueHeads
}
} else {
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
kv["attention.head_count"] = p.NumAttentionHeads
kv["attention.head_count_kv"] = kvHeadCounts
kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.layer_norm_rms_epsilon"] = p.NormEps
kv["shortconv.l_cache"] = p.ConvLCache
if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
kv["rope.freq_base"] = ropeFreqBase
}
if p.isMoE() {
kv["expert_count"] = p.expertCount()
kv["expert_used_count"] = p.NumExpertsPerToken
kv["expert_feed_forward_length"] = p.MoEIntermediateSize
kv["leading_dense_block_count"] = p.NumDenseLayers
kv["expert_gating_func"] = uint32(2) // sigmoid
kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0))
}
return kv
}
@@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
if p.isMoE() {
merges := make([]merge, 0, p.NumHiddenLayers*3)
for i := range p.NumHiddenLayers {
if i < p.NumDenseLayers {
continue
}
merges = append(merges, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
})
}
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
ts = remaining
}
for _, t := range ts {
shape := t.Shape()
@@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.embedding_norm", "token_embd_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
@@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string {
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.gate", "ffn_gate_inp",
"feed_forward.expert_bias", "exp_probs_b.bias",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",

View File

@@ -0,0 +1,271 @@
package convert
import (
"io"
"slices"
"strings"
"testing"
)
// lfm2StubTensor is a minimal Tensor used by the LFM2 conversion tests: it
// carries only a name and a shape and writes no payload bytes.
type lfm2StubTensor struct {
	tensorBase
}

// newLFM2StubTensor builds a stub tensor with the given name and shape.
func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
	stub := &lfm2StubTensor{}
	stub.name = name
	stub.shape = shape
	return stub
}

// WriteTo satisfies the Tensor interface without emitting any data.
func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
	return 0, nil
}

// Clone returns an independent copy whose shape slice shares no storage
// with the receiver.
func (t *lfm2StubTensor) Clone() Tensor {
	dup := &lfm2StubTensor{}
	dup.name = t.name
	dup.shape = slices.Clone(t.shape)
	return dup
}
// TestLFM2MoEKV verifies the metadata emitted for an LFM2 MoE config: the
// "lfm2moe" architecture name, the expert routing keys, the per-layer KV
// head counts derived from layer_types, and the rope base taken from the
// nested rope_parameters object.
func TestLFM2MoEKV(t *testing.T) {
	var p lfm2Model
	p.ModelParameters.ModelType = "lfm2_moe"
	p.VocabSize = 65536
	p.HiddenSize = 2048
	p.NumHiddenLayers = 4
	p.MaxPositionEmbeddings = 128000
	p.IntermediateSize = 11776
	p.NumAttentionHeads = 32
	p.NumKeyValueHeads = 8
	p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
	p.NormEps = 1e-5
	p.ConvLCache = 3
	p.MoEIntermediateSize = 1536
	p.NumExperts = 64
	p.NumExpertsPerToken = 4
	p.NumDenseLayers = 2
	// Rope theta supplied only via the nested rope_parameters, exercising
	// the fallback in ropeFreqBase.
	p.RopeParameters.RopeTheta = 1_000_000
	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
	if got, want := kv["general.architecture"], "lfm2moe"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}
	if got, want := kv["expert_count"], uint32(64); got != want {
		t.Fatalf("expert_count = %v, want %v", got, want)
	}
	if got, want := kv["expert_used_count"], uint32(4); got != want {
		t.Fatalf("expert_used_count = %v, want %v", got, want)
	}
	if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
		t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
	}
	if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
		t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
	}
	if got, want := kv["expert_gating_func"], uint32(2); got != want {
		t.Fatalf("expert_gating_func = %v, want %v", got, want)
	}
	gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
	}
	// conv layers contribute 0 KV heads; full_attention layers report
	// NumKeyValueHeads.
	wantHeadCounts := []uint32{0, 8, 0, 8}
	if !slices.Equal(gotHeadCounts, wantHeadCounts) {
		t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
	}
	if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
		t.Fatalf("rope.freq_base = %v, want %v", got, want)
	}
}
// TestLFM2DenseKV verifies that a dense (non-MoE) LFM2 config produces the
// plain "lfm2" architecture and emits no expert routing metadata.
func TestLFM2DenseKV(t *testing.T) {
	model := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 32000},
		HiddenSize:            1024,
		NumHiddenLayers:       2,
		MaxPositionEmbeddings: 32768,
		IntermediateSize:      4096,
		NumAttentionHeads:     16,
		NumKeyValueHeads:      4,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		RopeTheta:             10000,
	}
	kv := model.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}
	// A dense model must not carry any expert keys at all.
	if _, ok := kv["expert_count"]; ok {
		t.Fatalf("expert_count should not be set for dense lfm2")
	}
}
// TestLFM2MoETensors verifies expert tensor merging for an MoE model: the
// per-expert w1/w2/w3 weights of a post-dense layer are stacked into single
// ffn_{gate,down,up}_exps tensors (gaining a leading expert dimension), the
// shortconv weight has its singleton middle dimension dropped, and the
// original unmerged expert tensors do not survive.
func TestLFM2MoETensors(t *testing.T) {
	p := lfm2Model{
		ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
		NumHiddenLayers: 4,
		NumDenseLayers:  2,
	}
	// Layer 2 is the first non-dense layer (NumDenseLayers = 2), so its
	// expert tensors are eligible for merging.
	in := []Tensor{
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
	}
	out := p.Tensors(in)
	byName := make(map[string][]uint64, len(out))
	for _, tns := range out {
		byName[tns.Name] = tns.Shape
	}
	if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
	}
	if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
		t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
	}
	if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
	}
	// The [2048 1 3] shortconv weight is expected back squeezed to [2048 3].
	if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
		t.Fatalf("missing shortconv tensor")
	} else if !slices.Equal(got, []uint64{2048, 3}) {
		t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
	}
	if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
		t.Fatalf("unmerged expert tensor should not be present")
	}
}
// TestLFM2MoEReplacements checks that the MoE-specific tensor names are
// rewritten to their converted equivalents by the lfm2 replacement table.
func TestLFM2MoEReplacements(t *testing.T) {
	var model lfm2Model
	rep := strings.NewReplacer(model.Replacements()...)
	got := rep.Replace("model.layers.2.feed_forward.expert_bias")
	if want := "blk.2.exp_probs_b.bias"; got != want {
		t.Fatalf("expert bias replacement = %q, want %q", got, want)
	}
	got = rep.Replace("model.layers.2.feed_forward.gate.weight")
	if want := "blk.2.ffn_gate_inp.weight"; got != want {
		t.Fatalf("gate replacement = %q, want %q", got, want)
	}
}
// TestLFM2KVContextLengthEdgeCaseFallbackOverride verifies that a config
// matching the known fallback signature (the specific MoE shape combined
// with the default 128k max_position_embeddings) has its emitted context
// length overridden down to 32768.
func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       40,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            make([]string, 40),
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}
	// A mostly-conv layout with a single attention layer: the override
	// keys off the model signature, not the layer mix.
	for i := range p.LayerTypes {
		p.LayerTypes[i] = "conv"
	}
	p.LayerTypes[2] = "full_attention"
	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
	if got, want := kv["context_length"], uint32(32768); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}
// TestLFM2KVContextLengthNoOverride verifies that a config which does NOT
// match the fallback signature (here: 39 layers instead of 40) keeps its
// configured max_position_embeddings as the emitted context length.
func TestLFM2KVContextLengthNoOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       39, // mismatch: should not trigger edge case
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}
	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
	if got, want := kv["context_length"], uint32(128000); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}
// TestLFM2KVFeedForwardLengthAutoAdjust verifies the block_auto_adjust_ff_dim
// path: block_ff_dim (12288) takes precedence over intermediate_size, is
// scaled by 2/3 and the multiplier (1.0), and rounded up to a multiple of
// block_multiple_of — yielding 8192.
func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       16,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      12288, // should be ignored when block_ff_dim is set
		BlockFFDim:            12288,
		BlockAutoAdjustFFDim:  true,
		BlockMultipleOf:       256,
		BlockFFNDimMultiplier: 1.0,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
	}
	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
	// 2*12288/3 = 8192, already a multiple of 256.
	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}
}

417
convert/convert_lfm2_vl.go Normal file
View File

@@ -0,0 +1,417 @@
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io/fs"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
type lfm2VLTextModel struct {
TextConfig lfm2Model `json:"text_config"`
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
EncoderPatchSize uint32 `json:"encoder_patch_size"`
ImageTokenID uint32 `json:"image_token_id"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
TileSize uint32 `json:"tile_size"`
MaxPixelsTolerance float32 `json:"max_pixels_tolerance"`
ProjectorUseLayernorm bool `json:"projector_use_layernorm"`
ProjectorHiddenSize uint32 `json:"projector_hidden_size"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
UseImageSpecialTokens *bool `json:"use_image_special_tokens"`
UseThumbnail *bool `json:"use_thumbnail"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
} `json:"vision_config"`
Processor struct {
ImageProcessor struct {
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
MaxPixelsTol float32 `json:"max_pixels_tolerance"`
TileSize uint32 `json:"tile_size"`
UseThumbnail *bool `json:"use_thumbnail"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
Size struct {
Height uint32 `json:"height"`
Width uint32 `json:"width"`
} `json:"size"`
} `json:"image_processor"`
}
}
// textModel returns the embedded language-model configuration.
func (p *lfm2VLTextModel) textModel() *lfm2Model {
	return &p.TextConfig
}

// specialTokenTypes delegates to the wrapped text model's special token list.
func (p *lfm2VLTextModel) specialTokenTypes() []string {
	return p.textModel().specialTokenTypes()
}
// parseMore loads processor_config.json into p.Processor when the file is
// present; a missing file is not an error (built-in defaults apply instead).
func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
	data, err := fs.ReadFile(fsys, "processor_config.json")
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return nil
	case err != nil:
		return err
	}
	return json.Unmarshal(data, &p.Processor)
}
// visionImageSize derives the effective square image size: one tile edge
// divided by the projector downsample factor. LFM2-VL's image processor
// operates on 512-pixel tiles and downsamples by a factor of 2 before
// projection, so those are the fallback defaults.
func (p *lfm2VLTextModel) visionImageSize() uint32 {
	// Tile size falls back through the processor config fields to 512.
	tileEdge := p.Processor.ImageProcessor.TileSize
	if tileEdge == 0 {
		tileEdge = p.Processor.ImageProcessor.Size.Height
	}
	if tileEdge == 0 {
		tileEdge = p.Processor.ImageProcessor.Size.Width
	}
	if tileEdge == 0 {
		tileEdge = 512
	}
	factor := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	if factor == 0 {
		return tileEdge
	}
	// Never report a size below one, even for degenerate configs.
	if scaled := tileEdge / factor; scaled >= 1 {
		return scaled
	}
	return 1
}
// KV extends the text model's metadata with the vision encoder, projector,
// and image-preprocessing keys. For each value, the config.json field wins,
// then the processor_config.json field, then a hard-coded default.
func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
	kv := p.textModel().KV(t)
	// boolOr returns the first explicitly-set (non-nil) value, else the default.
	boolOr := func(defaultValue bool, values ...*bool) bool {
		for _, v := range values {
			if v != nil {
				return *v
			}
		}
		return defaultValue
	}
	kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
	kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
	kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
	kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
	kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
	kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
	kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
	kv["vision.image_size"] = p.visionImageSize()
	kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
	kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
	kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
	kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
	kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
	kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
	kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
	kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
	kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
	kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
	kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))
	// setVisionTokenID records the vocabulary index of a special token; the
	// key is simply left unset when the tokenizer or the token is absent.
	setVisionTokenID := func(k, token string) {
		if t == nil || t.Vocabulary == nil {
			return
		}
		for i, v := range t.Vocabulary.Tokens {
			if v == token {
				kv[k] = uint32(i)
				return
			}
		}
	}
	setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
	setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
	setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")
	return kv
}
// Tensors converts the combined text + vision tensor set. A 2-D (flattened)
// patch embedding weight gets a repacker that rearranges it from pixel-major
// interleaved-channel order to channel-major order, and its recorded shape is
// rewritten to the 4-D [out, channels, patch, patch] form; everything else is
// delegated to the text model's converter.
func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
	patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
	numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))
	for _, t := range ts {
		// NOTE(review): matching on the converted name here presumes the
		// replacement table has already been applied to t.Name() — confirm
		// against the conversion driver.
		if t.Name() == "v.patch_embd.weight" {
			shape := t.Shape()
			if len(shape) == 2 {
				// Only repack when the flat dimension matches the expected
				// channels*patch*patch layout; otherwise leave it untouched.
				inputDim := uint64(numChannels * patchSize * patchSize)
				if shape[1] == inputDim {
					// Capture as locals so the closure does not alias loop state.
					channels := numChannels
					patch := patchSize
					t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
						return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
					})
				}
			}
		}
	}
	out := p.textModel().Tensors(ts)
	// Reflect the repacked layout in the emitted tensor metadata.
	for _, t := range out {
		if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
			t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
		}
	}
	return out
}
// Replacements returns the old/new tensor-name pairs for the VL checkpoint.
// Every text-model pattern rooted at "model." is duplicated under the
// "model.language_model." prefixes used by VL checkpoints, then the vision
// tower and projector patterns are appended.
func (p *lfm2VLTextModel) Replacements() []string {
	out := make([]string, 0, 96)
	// addText registers a text-model pair plus its language_model-prefixed
	// variants so the same target name is produced for either layout.
	addText := func(from, to string) {
		out = append(out, from, to)
		if strings.HasPrefix(from, "model.") {
			suffix := strings.TrimPrefix(from, "model.")
			out = append(out,
				"model.language_model."+suffix, to,
				"model.language_model.model."+suffix, to,
			)
		}
	}
	base := p.textModel().Replacements()
	// Replacements come in (from, to) pairs; walk them pairwise.
	for i := 0; i+1 < len(base); i += 2 {
		addText(base[i], base[i+1])
	}
	// Vision tower + multimodal projector tensors (single-file conversion).
	out = append(out,
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
		"model.multi_modal_projector.layer_norm", "mm.layer_norm",
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
	)
	return out
}
// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
// lfm2VLProjectorModel converts the vision encoder + projector component of
// LFM2 VL checkpoints into a standalone mmproj file. Fields mirror
// config.json; Processor mirrors processor_config.json (loaded by parseMore).
type lfm2VLProjectorModel struct {
	ModelParameters
	DownsampleFactor   uint32 `json:"downsample_factor"`
	ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
	// VisionModel mirrors config.json's vision_config block.
	VisionModel struct {
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		ImageSize         uint32  `json:"image_size"`
	} `json:"vision_config"`
	// Processor carries processor_config.json fallbacks for image sizing.
	Processor struct {
		ImageProcessor struct {
			DownsampleFactor uint32    `json:"downsample_factor"`
			TileSize         uint32    `json:"tile_size"`
			ImageMean        []float32 `json:"image_mean"`
			ImageStd         []float32 `json:"image_std"`
			Size             struct {
				Height uint32 `json:"height"`
				Width  uint32 `json:"width"`
			} `json:"size"`
		} `json:"image_processor"`
	}
}

// Compile-time interface conformance checks for both VL converters.
var (
	_ ModelConverter = (*lfm2VLTextModel)(nil)
	_ ModelConverter = (*lfm2VLProjectorModel)(nil)
	_ moreParser     = (*lfm2VLTextModel)(nil)
	_ moreParser     = (*lfm2VLProjectorModel)(nil)
)
// parseMore loads processor_config.json into p.Processor when the file is
// present; a missing file is tolerated and leaves the defaults in place.
func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
	raw, err := fs.ReadFile(fsys, "processor_config.json")
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return nil
	case err != nil:
		return err
	}
	return json.Unmarshal(raw, &p.Processor)
}
// imageSize returns the image size recorded in the projector metadata: the
// explicitly configured vision image size when present, otherwise one tile
// edge divided by the downsample factor (defaults 256 and 2 respectively).
func (p *lfm2VLProjectorModel) imageSize() uint32 {
	if explicit := p.VisionModel.ImageSize; explicit > 0 {
		return explicit
	}
	factor := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	edge := cmp.Or(
		p.Processor.ImageProcessor.TileSize,
		p.Processor.ImageProcessor.Size.Height,
		p.Processor.ImageProcessor.Size.Width,
		uint32(256),
	)
	if factor == 0 {
		return edge
	}
	// Clamp at one so a degenerate config never reports a zero size.
	if scaled := edge / factor; scaled >= 1 {
		return scaled
	}
	return uint32(1)
}
// KV emits clip-style mmproj metadata for the vision encoder + projector.
// The tokenizer is unused because the projector file carries no vocabulary.
// Each vision value falls back to a hard-coded default when the config
// omits it.
func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
	kv := KV{
		"general.architecture":         "clip",
		"general.type":                 "mmproj",
		"general.file_type":            uint32(1),
		"general.quantization_version": uint32(2),
		"clip.has_vision_encoder":      true,
		"clip.projector_type":          "lfm2",
		"clip.use_gelu":                true,
	}
	kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
	kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
	kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
	kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
	kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
	kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
	kv["clip.vision.image_size"] = p.imageSize()
	kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
	kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	// Clone so later mutation of the processor config cannot alias the KV value.
	kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
	return kv
}
// defaultFloat32Slice returns v when it has at least one element and
// fallback otherwise. Neither slice is copied.
func defaultFloat32Slice(v, fallback []float32) []float32 {
	if len(v) == 0 {
		return fallback
	}
	return v
}
// Tensors keeps only the vision ("v.") and projector ("mm.") tensors for the
// mmproj file, repacking a flattened 2-D patch embedding weight into
// channel-major order and recording its 4-D shape.
func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor
	numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
	patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))
	for _, t := range ts {
		name := t.Name()
		// Text-model tensors are dropped; they belong to the main model file.
		if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
			continue
		}
		shape := t.Shape()
		if name == "v.patch_embd.weight" && len(shape) == 2 {
			// Only repack when the flat dimension matches channels*patch*patch.
			inputDim := uint64(numChannels * patchSize * patchSize)
			if shape[1] == inputDim {
				shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
				// Capture as locals so the closure does not alias loop state.
				channels := int(numChannels)
				patch := int(patchSize)
				t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
					return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
				})
			}
		}
		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    slices.Clone(shape),
			WriterTo: t,
		})
	}
	return out
}
// Replacements returns ordered old/new name pairs (consumed by a
// strings.Replacer) mapping HF LFM2-VL tensor names to GGUF names for the
// vision tower and multimodal projector.
func (p *lfm2VLProjectorModel) Replacements() []string {
	return []string{
		// Multimodal projector
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		// Vision embeddings
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		// Vision encoder blocks and their sub-tensors
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
		// Final post-encoder norm
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
	}
}
// repackPatchEmbeddingWeight reorders a flattened patch-embedding weight
// from pixel-major layout (channels interleaved per pixel) to channel-major
// layout (one contiguous patch*patch plane per channel), per output row.
// srcShape must be [outDim, channels*patch*patch] and data must contain
// exactly outDim*channels*patch*patch elements.
func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
	if len(srcShape) != 2 {
		return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
	}
	rows := int(srcShape[0])
	cols := int(srcShape[1])
	if want := channels * patch * patch; cols != want {
		return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", cols, want)
	}
	if want := rows * cols; len(data) != want {
		return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), want)
	}
	out := make([]float32, len(data))
	plane := patch * patch
	for row := 0; row < rows; row++ {
		base := row * cols
		for c := 0; c < channels; c++ {
			for px := 0; px < plane; px++ {
				// Source: pixel px with channels innermost.
				// Destination: channel plane c with pixels innermost.
				out[base+c*plane+px] = data[base+px*channels+c]
			}
		}
	}
	return out, nil
}

View File

@@ -0,0 +1,249 @@
package convert
import (
"slices"
"strings"
"testing"
)
// TestLFM2VLTextModelKVUsesTextConfig builds an LFM2-VL converter from the
// nested text/vision configs and checks that KV() emits the lfm2
// architecture, the auto-adjusted feed_forward_length, and the
// vision.* keys (tile sizing, image token IDs derived from the tokenizer).
func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
	p := lfm2VLTextModel{
		TextConfig: lfm2Model{
			ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
			HiddenSize:            2048,
			NumHiddenLayers:       16,
			MaxPositionEmbeddings: 128000,
			IntermediateSize:      12288,
			BlockFFDim:            12288,
			BlockAutoAdjustFFDim:  true,
			BlockMultipleOf:       256,
			BlockFFNDimMultiplier: 1.0,
			NumAttentionHeads:     32,
			NumKeyValueHeads:      8,
			LayerTypes:            []string{"conv", "full_attention"},
			NormEps:               1e-5,
			ConvLCache:            3,
		},
		DownsampleFactor: 2,
		VisionConfig: struct {
			HiddenSize       uint32  `json:"hidden_size"`
			IntermediateSize uint32  `json:"intermediate_size"`
			NumAttentionHeads uint32 `json:"num_attention_heads"`
			NumHiddenLayers  uint32  `json:"num_hidden_layers"`
			NumChannels      uint32  `json:"num_channels"`
			PatchSize        uint32  `json:"patch_size"`
			LayerNormEpsilon float32 `json:"layer_norm_eps"`
		}{
			HiddenSize:       1152,
			IntermediateSize: 4304,
			NumAttentionHeads: 16,
			NumHiddenLayers:  27,
			NumChannels:      3,
			PatchSize:        16,
			LayerNormEpsilon: 1e-6,
		},
	}
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
	// "<image>" sits at index 1 of this vocabulary; the KV assertions below
	// check the resulting special-token IDs.
	kv := p.KV(&Tokenizer{
		Vocabulary: &Vocabulary{
			Model:  "gpt2",
			Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
		},
	})
	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}
	if got, want := kv["vision.block_count"], uint32(27); got != want {
		t.Fatalf("vision.block_count = %v, want %v", got, want)
	}
	if got, want := kv["vision.image_size"], uint32(256); got != want {
		t.Fatalf("vision.image_size = %v, want %v", got, want)
	}
	if got, want := kv["vision.image_token_id"], uint32(396); got != want {
		t.Fatalf("vision.image_token_id = %v, want %v", got, want)
	}
	if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
		t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
	}
	if got, want := kv["vision.do_image_splitting"], true; got != want {
		t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
	}
	if got, want := kv["vision.min_tiles"], uint32(2); got != want {
		t.Fatalf("vision.min_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.max_tiles"], uint32(10); got != want {
		t.Fatalf("vision.max_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.tile_size"], uint32(512); got != want {
		t.Fatalf("vision.tile_size = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_thumbnail"], true; got != want {
		t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_image_special_tokens"], true; got != want {
		t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
	}
}
// TestLFM2VLTextModelTensorsIncludeVision checks that Tensors() keeps
// vision ("v.") and projector ("mm.") tensors and reshapes the flattened
// patch embedding [1152, 768] into conv form [1152, 3, 16, 16].
func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
	p := lfm2VLTextModel{}
	p.VisionConfig.PatchSize = 16
	p.VisionConfig.NumChannels = 3
	input := []Tensor{
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
		newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
	}
	out := p.Tensors(input)
	if len(out) == 0 {
		t.Fatal("expected non-empty tensor list")
	}
	foundPatch := false
	foundVision := false
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			foundPatch = true
			if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
				t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
			}
		}
		if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
			foundVision = true
		}
	}
	if !foundPatch {
		t.Fatal("expected v.patch_embd.weight in output tensors")
	}
	if !foundVision {
		t.Fatal("expected at least one vision/projector tensor in output")
	}
}
// TestLFM2VLTextModelReplacements verifies the HF -> GGUF tensor-name
// mapping for the text model, including the nested
// "model.language_model.model." prefix some checkpoints use.
func TestLFM2VLTextModelReplacements(t *testing.T) {
	p := lfm2VLTextModel{}
	r := strings.NewReplacer(p.Replacements()...)
	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "language_model_embed_tokens",
			in:   "model.language_model.embed_tokens.weight",
			want: "token_embd.weight",
		},
		{
			name: "language_model_layers",
			in:   "model.language_model.layers.2.self_attn.q_proj.weight",
			want: "blk.2.attn_q.weight",
		},
		{
			name: "nested_language_model_prefix",
			in:   "model.language_model.model.embedding_norm.weight",
			want: "token_embd_norm.weight",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := r.Replace(tt.in); got != tt.want {
				t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}
// TestLFM2VLProjectorKV checks the clip-architecture metadata emitted for
// the LFM2-VL projector, including the derived image size (tile_size /
// downsample_factor = 512/2 = 256).
func TestLFM2VLProjectorKV(t *testing.T) {
	p := lfm2VLProjectorModel{
		DownsampleFactor:   2,
		ProjectorHiddenDim: 2048,
	}
	p.VisionModel.NumHiddenLayers = 27
	p.VisionModel.HiddenSize = 1152
	p.VisionModel.IntermediateSize = 4304
	p.VisionModel.NumAttentionHeads = 16
	p.VisionModel.PatchSize = 16
	p.VisionModel.LayerNormEpsilon = 1e-6
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
	kv := p.KV(nil)
	if got, want := kv["general.architecture"], "clip"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["clip.projector_type"], "lfm2"; got != want {
		t.Fatalf("clip.projector_type = %v, want %v", got, want)
	}
	if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
		t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
	}
}
// TestLFM2VLProjectorTensorsPatchReshape verifies that the projector
// converter drops non-vision tensors and reshapes the patch embedding
// into [out, channels, patch, patch].
func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
	p := lfm2VLProjectorModel{}
	p.VisionModel.NumChannels = 3
	p.VisionModel.PatchSize = 16
	input := []Tensor{
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
		// Language-model tensor: must be filtered out.
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
	}
	out := p.Tensors(input)
	if len(out) != 2 {
		t.Fatalf("expected 2 tensors, got %d", len(out))
	}
	var patchShape []uint64
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			patchShape = tns.Shape
			break
		}
	}
	if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
		t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
	}
}
// TestRepackPatchEmbeddingWeight exercises the interleaved-channel to
// channel-major reorder on a 1x(2*2*2) weight: channel planes become
// contiguous in the output.
func TestRepackPatchEmbeddingWeight(t *testing.T) {
	data := []float32{
		0, 1, // y=0,x=0
		2, 3, // y=0,x=1
		4, 5, // y=1,x=0
		6, 7, // y=1,x=1
	}
	got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
	if !slices.Equal(got, want) {
		t.Fatalf("repacked data = %v, want %v", got, want)
	}
}

View File

@@ -0,0 +1,385 @@
package convert
import (
"cmp"
"encoding/json"
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type hybridPattern string
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*p = ""
return nil
}
var single string
if err := json.Unmarshal(data, &single); err == nil {
*p = hybridPattern(strings.TrimSpace(single))
return nil
}
var parts []string
if err := json.Unmarshal(data, &parts); err == nil {
*p = hybridPattern(strings.Join(parts, ""))
return nil
}
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}
// nemotronHModel converts Nemotron-H checkpoints: hybrid models mixing
// Mamba2 (SSM), attention, and FFN/MoE layers as described by
// hybrid_override_pattern. Field tags mirror the HF config.json keys.
type nemotronHModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	HeadDim               uint32  `json:"head_dim"`
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	NormEpsilon           float32 `json:"norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	PartialRotaryFactor   float32 `json:"partial_rotary_factor"`
	// SSM (Mamba2) parameters
	ConvKernel            uint32        `json:"conv_kernel"`
	SSMStateSize          uint32        `json:"ssm_state_size"`
	MambaNumHeads         uint32        `json:"mamba_num_heads"`
	MambaHeadDim          uint32        `json:"mamba_head_dim"`
	NGroups               uint32        `json:"n_groups"`
	IntermediateSize      uint32        `json:"intermediate_size"`
	HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
	// MoE. Both num_*/n_* spellings appear across configs; the accessor
	// methods (routedExpertCount, sharedExpertCount) prefer the n_* form.
	NumExperts                  uint32  `json:"num_experts"`
	NumSharedExperts            uint32  `json:"num_shared_experts"`
	NRoutedExperts              uint32  `json:"n_routed_experts"`
	NSharedExperts              uint32  `json:"n_shared_experts"`
	NumExpertsPerTok            uint32  `json:"num_experts_per_tok"`
	MoEIntermediateSize         uint32  `json:"moe_intermediate_size"`
	MoESharedExpertIntermediate uint32  `json:"moe_shared_expert_intermediate_size"`
	NormTopKProb                bool    `json:"norm_topk_prob"`
	RoutedScalingFactor         float32 `json:"routed_scaling_factor"`
	ExpertGroupCount            uint32  `json:"n_group"`
	ExpertGroupUsedCount        uint32  `json:"topk_group"`
}
var _ ModelConverter = (*nemotronHModel)(nil)
// parseMore validates required config fields after JSON decoding and
// derives defaults: head_dim from hidden_size, num_key_value_heads from
// num_attention_heads, and n_groups = 1. It also validates the hybrid
// layer pattern and MoE settings so conversion fails early with a clear
// error instead of producing a broken GGUF.
func (n *nemotronHModel) parseMore(_ fs.FS) error {
	if n.NumHiddenLayers == 0 {
		return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
	}
	if n.HiddenSize == 0 {
		return fmt.Errorf("nemotron_h: hidden_size must be set")
	}
	if n.NumAttentionHeads == 0 {
		return fmt.Errorf("nemotron_h: num_attention_heads must be set")
	}
	if n.HeadDim == 0 {
		// Derive head_dim when the config omits it.
		if n.HiddenSize%n.NumAttentionHeads != 0 {
			return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
		}
		n.HeadDim = n.HiddenSize / n.NumAttentionHeads
	}
	if n.NumKeyValueHeads == 0 {
		// No GQA config: assume full multi-head attention.
		n.NumKeyValueHeads = n.NumAttentionHeads
	}
	if n.ConvKernel == 0 {
		return fmt.Errorf("nemotron_h: conv_kernel must be set")
	}
	if n.SSMStateSize == 0 {
		return fmt.Errorf("nemotron_h: ssm_state_size must be set")
	}
	if n.ssmHeadCount() == 0 {
		return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
	}
	if n.MambaHeadDim == 0 {
		return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
	}
	if n.NGroups == 0 {
		n.NGroups = 1
	}
	// Validate hybrid_override_pattern length and characters now; KV()
	// silently skips the per-layer arrays on error.
	if _, _, err := n.layerArrays(); err != nil {
		return err
	}
	if n.isMoE() {
		if n.routedExpertCount() == 0 {
			return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
		}
		if n.NumExpertsPerTok == 0 {
			return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
		}
		if n.NumExpertsPerTok > n.routedExpertCount() {
			return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
		}
		if n.moeIntermediateSize() == 0 {
			return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
		}
	}
	return nil
}
// isMoE reports whether any MoE-related config field is set, marking the
// checkpoint as a mixture-of-experts model.
func (n *nemotronHModel) isMoE() bool {
	return n.routedExpertCount() > 0 || n.NumExpertsPerTok > 0 || n.MoEIntermediateSize > 0
}
// routedExpertCount returns the routed expert count, preferring the
// n_routed_experts spelling over num_experts.
func (n *nemotronHModel) routedExpertCount() uint32 {
	if n.NRoutedExperts != 0 {
		return n.NRoutedExperts
	}
	return n.NumExperts
}
// sharedExpertCount returns the shared expert count, preferring the
// n_shared_experts spelling over num_shared_experts.
func (n *nemotronHModel) sharedExpertCount() uint32 {
	if n.NSharedExperts != 0 {
		return n.NSharedExperts
	}
	return n.NumSharedExperts
}
// ssmHeadCount returns the number of Mamba2 heads (mamba_num_heads).
func (n *nemotronHModel) ssmHeadCount() uint32 {
	return n.MambaNumHeads
}
// ssmInnerSize returns the SSM inner dimension: heads * head_dim.
func (n *nemotronHModel) ssmInnerSize() uint32 {
	return n.ssmHeadCount() * n.MambaHeadDim
}
// epsilon returns the norm epsilon: norm_eps if set, else
// layer_norm_epsilon, else 1e-5.
func (n *nemotronHModel) epsilon() float32 {
	if n.NormEpsilon != 0 {
		return n.NormEpsilon
	}
	if n.LayerNormEpsilon != 0 {
		return n.LayerNormEpsilon
	}
	return 1e-5
}
// moeIntermediateSize returns the per-expert FFN size, falling back to the
// dense intermediate_size when moe_intermediate_size is unset.
func (n *nemotronHModel) moeIntermediateSize() uint32 {
	if n.MoEIntermediateSize != 0 {
		return n.MoEIntermediateSize
	}
	return n.IntermediateSize
}
// denseIntermediateSize returns the dense FFN size, falling back to the
// MoE intermediate size when intermediate_size is unset.
func (n *nemotronHModel) denseIntermediateSize() uint32 {
	if n.IntermediateSize != 0 {
		return n.IntermediateSize
	}
	return n.MoEIntermediateSize
}
// layerArrays expands hybrid_override_pattern into two per-layer arrays:
// the KV head count (non-zero only on attention layers) and the FFN length
// (non-zero only on FFN/MoE layers). Pattern characters:
//
//	'M'      recurrent (Mamba2) layer — no KV heads, no FFN
//	'*', 'A' attention layer
//	'E'      MoE FFN layer
//	'-'      dense FFN layer
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
	pattern := strings.TrimSpace(string(n.HybridOverridePattern))
	if pattern == "" {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
	}
	// Compare in runes so a multi-byte character can't skew the length check.
	runes := []rune(pattern)
	if len(runes) != int(n.NumHiddenLayers) {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
	}
	headCountKV = make([]uint32, n.NumHiddenLayers)
	ffnLengths = make([]uint32, n.NumHiddenLayers)
	attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
	moeFFN := n.moeIntermediateSize()
	denseFFN := n.denseIntermediateSize()
	for i, layerType := range runes {
		switch layerType {
		case 'M':
			// Recurrent layer: no KV heads and no FFN.
		case '*', 'A':
			// Attention-only layer.
			headCountKV[i] = attnKVHeads
		case 'E':
			// MoE layer.
			if moeFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
			}
			ffnLengths[i] = moeFFN
		case '-':
			// Dense FFN layer.
			if denseFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
			}
			ffnLengths[i] = denseFFN
		default:
			return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
		}
	}
	return headCountKV, ffnLengths, nil
}
// KV builds the GGUF metadata for a Nemotron-H checkpoint. MoE models use
// the nemotron_h_moe architecture and additionally emit expert metadata.
func (n *nemotronHModel) KV(t *Tokenizer) KV {
	kv := n.ModelParameters.KV(t)
	arch := "nemotron_h"
	if n.isMoE() {
		arch = "nemotron_h_moe"
	}
	kv["general.architecture"] = arch
	kv["block_count"] = n.NumHiddenLayers
	kv["context_length"] = n.MaxPositionEmbeddings
	kv["embedding_length"] = n.HiddenSize
	kv["attention.head_count"] = n.NumAttentionHeads
	kv["attention.key_length"] = n.HeadDim
	kv["attention.value_length"] = n.HeadDim
	kv["attention.layer_norm_epsilon"] = n.epsilon()
	kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
	kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
	// Partial RoPE: only a fraction of each head's dimensions are rotated.
	if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
		kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
	}
	// Per-layer arrays; parseMore already rejected invalid patterns, so the
	// error branch is not expected here.
	if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
		kv["attention.head_count_kv"] = headCountKV
		kv["feed_forward_length"] = ffnLengths
	}
	kv["ssm.conv_kernel"] = n.ConvKernel
	kv["ssm.inner_size"] = n.ssmInnerSize()
	kv["ssm.state_size"] = n.SSMStateSize
	kv["ssm.group_count"] = n.NGroups
	kv["ssm.time_step_rank"] = n.ssmHeadCount()
	if n.isMoE() {
		kv["expert_count"] = n.routedExpertCount()
		kv["expert_used_count"] = n.NumExpertsPerTok
		kv["expert_feed_forward_length"] = n.moeIntermediateSize()
		if n.sharedExpertCount() > 0 {
			kv["expert_shared_count"] = n.sharedExpertCount()
		}
		if n.MoESharedExpertIntermediate > 0 {
			kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
		}
		kv["expert_weights_norm"] = n.NormTopKProb
		kv["expert_weights_scale"] = n.RoutedScalingFactor
		if n.ExpertGroupCount > 0 {
			kv["expert_group_count"] = n.ExpertGroupCount
		}
		if n.ExpertGroupUsedCount > 0 {
			kv["expert_group_used_count"] = n.ExpertGroupUsedCount
		}
	}
	return kv
}
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
switch len(shape) {
case 1:
return []uint64{shape[0], 1}
case 2:
if shape[0] == 1 && shape[1] > 1 {
return []uint64{shape[1], 1}
}
if shape[1] == 1 && shape[0] > 1 {
return []uint64{shape[0], 1}
}
}
return slices.Clone(shape)
}
// Tensors converts checkpoint tensors to GGUF layout. For MoE models the
// per-expert up/down projections are first merged into stacked expert
// tensors; remaining SSM tensors get shape normalization, and ssm_a is
// repacked from A_log to A = -exp(A_log).
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor
	remaining := ts
	if n.isMoE() {
		// Stack the individual expert weights per layer into single
		// ffn_up_exps/ffn_down_exps tensors.
		merges := make([]merge, 0, n.NumHiddenLayers*2)
		for i := range n.NumHiddenLayers {
			merges = append(merges, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
			}, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
			})
		}
		merged, rest := mergeTensors(ts, merges...)
		out = append(out, merged...)
		remaining = rest
	}
	nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
	for _, t := range remaining {
		name := t.Name()
		shape := slices.Clone(t.Shape())
		switch {
		case strings.HasSuffix(name, ".ssm_a"):
			// Stored as A_log; convert to A = -exp(A_log) at write time.
			shape = normalizeVectorShapeToColumn(shape)
			t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
				out := make([]float32, len(data))
				for i, v := range data {
					out[i] = -float32(math.Exp(float64(v)))
				}
				return out, nil
			})
		case strings.HasSuffix(name, ".ssm_d"):
			shape = normalizeVectorShapeToColumn(shape)
		case strings.HasSuffix(name, ".ssm_norm.weight"):
			// Split the flat norm weight into one row per SSM group.
			switch len(shape) {
			case 1:
				if nGroups > 0 && shape[0]%nGroups == 0 {
					shape = []uint64{nGroups, shape[0] / nGroups}
				}
			case 2:
				if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
					shape = []uint64{nGroups, shape[1] / nGroups}
				}
			}
		case strings.HasSuffix(name, ".ssm_conv1d.weight"):
			// Drop the singleton middle/leading dim of the conv weight.
			if len(shape) == 3 {
				if shape[0] == 1 {
					shape = []uint64{shape[1], shape[2]}
				} else if shape[1] == 1 {
					shape = []uint64{shape[0], shape[2]}
				}
			}
		}
		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    shape,
			WriterTo: t,
		})
	}
	return out
}
// Replacements returns ordered old/new name pairs mapping HF Nemotron-H
// tensor names to GGUF names. Order matters to strings.NewReplacer: longer,
// more specific patterns (e.g. mixer.gate.e_score_correction_bias) must
// precede their shorter prefixes (mixer.gate).
func (n *nemotronHModel) Replacements() []string {
	return []string{
		// Embedding and output
		"lm_head", "output",
		"backbone.embeddings", "token_embd",
		"backbone.norm_f", "output_norm",
		"backbone.layers", "blk",
		// Recurrent (Mamba2) tensors
		"mixer.in_proj", "ssm_in",
		"mixer.out_proj", "ssm_out",
		"mixer.dt_bias", "ssm_dt.bias",
		"mixer.A_log", "ssm_a",
		"mixer.D", "ssm_d",
		"mixer.conv1d", "ssm_conv1d",
		"mixer.norm.weight", "ssm_norm.weight",
		// Attention tensors
		"mixer.q_proj", "attn_q",
		"mixer.k_proj", "attn_k",
		"mixer.v_proj", "attn_v",
		"mixer.o_proj", "attn_output",
		// FFN / MoE tensors
		"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mixer.gate", "ffn_gate_inp",
		"mixer.fc1_latent_proj", "ffn_latent_in",
		"mixer.fc2_latent_proj", "ffn_latent_out",
		"mixer.shared_experts.up_proj", "ffn_up_shexp",
		"mixer.shared_experts.down_proj", "ffn_down_shexp",
		"mixer.up_proj", "ffn_up",
		"mixer.down_proj", "ffn_down",
		// Per-layer pre-norm
		".norm.weight", ".attn_norm.weight",
	}
}

View File

@@ -0,0 +1,230 @@
package convert
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
// TestHybridPatternUnmarshal checks that hybridPattern decodes both the
// single-string and string-array JSON encodings to the same pattern.
func TestHybridPatternUnmarshal(t *testing.T) {
	t.Run("string", func(t *testing.T) {
		var p hybridPattern
		if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
			t.Fatal(err)
		}
		if got, want := string(p), "MEM*"; got != want {
			t.Fatalf("unexpected pattern: got %q want %q", got, want)
		}
	})
	t.Run("array", func(t *testing.T) {
		var p hybridPattern
		if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
			t.Fatal(err)
		}
		if got, want := string(p), "MEM*"; got != want {
			t.Fatalf("unexpected pattern: got %q want %q", got, want)
		}
	})
}
// TestNemotronHLayerArrays expands the "MEM*E" pattern: KV heads appear
// only at the attention layer ('*'), FFN lengths only at MoE layers ('E').
func TestNemotronHLayerArrays(t *testing.T) {
	m := &nemotronHModel{
		NumHiddenLayers:       5,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		HybridOverridePattern: "MEM*E",
		NRoutedExperts:        128,
		NumExpertsPerTok:      6,
		MoEIntermediateSize:   1856,
	}
	headsKV, ffn, err := m.layerArrays()
	if err != nil {
		t.Fatal(err)
	}
	if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
		t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
	}
	if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
		t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
	}
}
// TestNemotronHKV runs parseMore + KV on a MoE config and checks the
// selected architecture and the per-layer head_count_kv /
// feed_forward_length arrays.
func TestNemotronHKV(t *testing.T) {
	m := &nemotronHModel{
		MaxPositionEmbeddings:       1048576,
		HiddenSize:                  2688,
		NumHiddenLayers:             5,
		NumAttentionHeads:           32,
		NumKeyValueHeads:            2,
		HeadDim:                     128,
		LayerNormEpsilon:            1e-5,
		RopeTheta:                   10000,
		PartialRotaryFactor:         0.5,
		ConvKernel:                  4,
		SSMStateSize:                128,
		MambaNumHeads:               64,
		MambaHeadDim:                64,
		NGroups:                     8,
		HybridOverridePattern:       "MEM*E",
		NRoutedExperts:              128,
		NSharedExperts:              1,
		NumExpertsPerTok:            6,
		MoEIntermediateSize:         1856,
		MoESharedExpertIntermediate: 3712,
		NormTopKProb:                true,
		RoutedScalingFactor:         2.5,
	}
	if err := m.parseMore(nil); err != nil {
		t.Fatal(err)
	}
	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
		t.Fatalf("unexpected architecture: got %v want %v", got, want)
	}
	headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
	}
	if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
		t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
	}
	ffnLength, ok := kv["feed_forward_length"].([]uint32)
	if !ok {
		t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
	}
	if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
		t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
	}
}
// TestNemotronHTensorsTransforms checks the SSM tensor shape
// normalizations (column vectors, per-group norm split, conv squeeze) and
// that the ssm_a repacker applies A = -exp(A_log) to the written data.
func TestNemotronHTensorsTransforms(t *testing.T) {
	m := &nemotronHModel{NGroups: 8}
	in := []Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_a",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_d",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_norm.weight",
			shape: []uint64{16},
			data:  make([]float32, 16),
		},
		&fakeTensor{
			name:  "blk.0.ssm_conv1d.weight",
			shape: []uint64{10, 1, 4},
			data:  make([]float32, 40),
		},
	}
	out := m.Tensors(in)
	if len(out) != len(in) {
		t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
	}
	got := map[string]struct {
		shape  []uint64
		writer io.WriterTo
	}{}
	for _, t := range out {
		got[t.Name] = struct {
			shape  []uint64
			writer io.WriterTo
		}{shape: t.Shape, writer: t.WriterTo}
	}
	if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_a shape: %v", shape)
	}
	if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_d shape: %v", shape)
	}
	if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
		t.Fatalf("unexpected ssm_norm shape: %v", shape)
	}
	if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
		t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
	}
	// Writing the tensor runs the repacker; read back the first value.
	var b bytes.Buffer
	if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
		t.Fatal(err)
	}
	values := make([]float32, 4)
	if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
		t.Fatal(err)
	}
	// 0 -> -exp(0) == -1
	if values[0] != -1 {
		t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
	}
}
// TestNemotronHLoadModelMetadata writes a minimal Nemotron-H config.json
// to disk and checks that LoadModelMetadata selects the nemotronHModel
// converter for the NemotronHForCausalLM architecture.
func TestNemotronHLoadModelMetadata(t *testing.T) {
	tempDir := t.TempDir()
	config := `{
		"architectures": ["NemotronHForCausalLM"],
		"model_type": "nemotron_h",
		"num_hidden_layers": 4,
		"hidden_size": 512,
		"max_position_embeddings": 32768,
		"num_attention_heads": 8,
		"num_key_value_heads": 2,
		"head_dim": 64,
		"layer_norm_epsilon": 1e-5,
		"conv_kernel": 4,
		"ssm_state_size": 128,
		"mamba_num_heads": 16,
		"mamba_head_dim": 32,
		"n_groups": 8,
		"hybrid_override_pattern": "ME*M",
		"n_routed_experts": 16,
		"num_experts_per_tok": 4,
		"moe_intermediate_size": 256
	}`
	if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
		t.Fatal(err)
	}
	kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
	if err != nil {
		t.Fatal(err)
	}
	if _, ok := kv.(*nemotronHModel); !ok {
		t.Fatalf("unexpected converter type: %T", kv)
	}
}
// TestNemotronHReplacementsLatentProjections checks the fc1/fc2 latent
// projection name mappings to ffn_latent_in/ffn_latent_out.
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
	m := &nemotronHModel{}
	r := strings.NewReplacer(m.Replacements()...)
	if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
		t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
	}
	if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
		t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
	}
}

97
convert/json_compat.go Normal file
View File

@@ -0,0 +1,97 @@
package convert
// sanitizeNonFiniteJSON rewrites the non-standard numeric tokens some HF
// configs emit (Infinity, -Infinity, NaN) into the standard JSON number 0,
// which encoding/json can parse.
//
// It is intentionally conservative: tokens inside quoted strings are left
// alone, and a token only matches when it is bounded by characters that
// can legally surround a JSON value (so identifier-like text such as
// "InfinityValue" is untouched). These fields are typically model-side
// metadata not consumed by the converter, so 0 is an acceptable stand-in.
func sanitizeNonFiniteJSON(in []byte) []byte {
	if len(in) == 0 {
		return in
	}
	okBefore := func(b byte) bool {
		switch b {
		case ' ', '\t', '\n', '\r', ':', ',', '[':
			return true
		}
		return false
	}
	okAfter := func(b byte) bool {
		switch b {
		case ' ', '\t', '\n', '\r', ',', ']', '}':
			return true
		}
		return false
	}
	// tokenAt reports whether tok occurs at offset at as a full value token.
	tokenAt := func(at int, tok string) bool {
		end := at + len(tok)
		if end > len(in) || string(in[at:end]) != tok {
			return false
		}
		if at > 0 && !okBefore(in[at-1]) {
			return false
		}
		if end < len(in) && !okAfter(in[end]) {
			return false
		}
		return true
	}
	// "-Infinity" must be tried before "Infinity" so the sign is consumed.
	tokens := []string{"-Infinity", "Infinity", "NaN"}
	out := make([]byte, 0, len(in))
	var inString, escaped bool
	for i := 0; i < len(in); {
		c := in[i]
		if inString {
			out = append(out, c)
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			i++
			continue
		}
		if c == '"' {
			inString = true
			out = append(out, c)
			i++
			continue
		}
		replaced := false
		for _, tok := range tokens {
			if tokenAt(i, tok) {
				out = append(out, '0')
				i += len(tok)
				replaced = true
				break
			}
		}
		if !replaced {
			out = append(out, c)
			i++
		}
	}
	return out
}
// hasToken reports whether tok appears at offset at in in as a complete
// JSON value token: the preceding byte (if any) must be whitespace, ':',
// ',' or '[', and the following byte (if any) must be whitespace, ',',
// ']' or '}'.
func hasToken(in []byte, at int, tok string) bool {
	end := at + len(tok)
	if at < 0 || end > len(in) || string(in[at:end]) != tok {
		return false
	}
	if at > 0 {
		switch in[at-1] {
		case ' ', '\t', '\n', '\r', ':', ',', '[':
		default:
			return false
		}
	}
	if end < len(in) {
		switch in[end] {
		case ' ', '\t', '\n', '\r', ',', ']', '}':
		default:
			return false
		}
	}
	return true
}
// isJSONWhitespace reports whether b is one of the four JSON whitespace
// characters (space, tab, newline, carriage return).
func isJSONWhitespace(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\r':
		return true
	default:
		return false
	}
}
// isJSONValuePrefixBoundary reports whether b may legally precede a JSON
// value token: whitespace, a key/value separator, an element separator,
// or an array open bracket.
func isJSONValuePrefixBoundary(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\r', ':', ',', '[':
		return true
	default:
		return false
	}
}
// isJSONValueSuffixBoundary reports whether b may legally follow a JSON
// value token: whitespace, an element separator, or a closing bracket or
// brace.
func isJSONValueSuffixBoundary(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\r', ',', ']', '}':
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,46 @@
package convert
import "testing"
// TestSanitizeNonFiniteJSON covers token replacement, string-literal
// preservation, and identifier-like near-misses.
func TestSanitizeNonFiniteJSON(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "infinity token",
			in:   `{"a":[0,Infinity,1]}`,
			want: `{"a":[0,0,1]}`,
		},
		{
			name: "negative infinity token",
			in:   `{"a":-Infinity}`,
			want: `{"a":0}`,
		},
		{
			name: "nan token",
			in:   `{"a":NaN}`,
			want: `{"a":0}`,
		},
		{
			name: "tokens inside strings untouched",
			in:   `{"a":"Infinity -Infinity NaN","b":Infinity}`,
			want: `{"a":"Infinity -Infinity NaN","b":0}`,
		},
		{
			name: "identifier-like token untouched",
			in:   `{"a":InfinityValue}`,
			want: `{"a":InfinityValue}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
			if got != tt.want {
				t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
			}
		})
	}
}

View File

@@ -212,8 +212,13 @@ type tokenizer struct {
PreTokenizer struct {
PreTokenizers []struct {
Type string `json:"type"`
Pattern struct {
Type string `json:"type"`
Behavior string `json:"behavior"`
Invert bool `json:"invert"`
AddPrefixSpace bool `json:"add_prefix_space"`
TrimOffsets bool `json:"trim_offsets"`
UseRegex bool `json:"use_regex"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`

View File

@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "llama-bpe pretokenizer and control tokens",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{"id": 1, "content": "<|startoftext|>", "special": true},
{"id": 6, "content": "<|im_start|>", "special": true},
{"id": 7, "content": "<|im_end|>", "special": true},
{"id": 8, "content": "<|tool_list_start|>", "special": true},
{"id": 9, "content": "<|tool_list_end|>", "special": true},
{"id": 10, "content": "<|tool_call_start|>", "special": true},
{"id": 11, "content": "<|tool_call_end|>", "special": true},
{"id": 12, "content": "<|tool_response_start|>", "special": true},
{"id": 13, "content": "<|tool_response_end|>", "special": true},
{"id": 396, "content": "<image>", "special": true},
{"id": 64400, "content": "<think>", "special": true},
{"id": 64401, "content": "</think>", "special": true}
],
"model": {
"vocab": {
"<|startoftext|>": 1,
"<|im_start|>": 6,
"<|im_end|>": 7,
"<|tool_list_start|>": 8,
"<|tool_list_end|>": 9,
"<|tool_call_start|>": 10,
"<|tool_call_end|>": 11,
"<|tool_response_start|>": 12,
"<|tool_response_end|>": 13,
"<image>": 396,
"<think>": 64400,
"</think>": 64401
}
},
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{
"<|startoftext|>",
"<|im_start|>",
"<|im_end|>",
"<|tool_list_start|>",
"<|tool_list_end|>",
"<|tool_call_start|>",
"<|tool_call_end|>",
"<|tool_response_start|>",
"<|tool_response_end|>",
"<image>",
"<think>",
"</think>",
},
Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
Types: []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
},
Pre: "llama-bpe",
},
},
{
name: "list string merges",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{

View File

@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
// FFNLength returns the feed-forward length expanded to one entry per
// transformer block. The "feed_forward_length" key may be a scalar
// (applied uniformly to all layers) or a per-layer array; short arrays are
// padded with the default (the scalar value when one was given, else 0).
func (kv KV) FFNLength() []uint64 {
	ffnLengthDefault := uint32(0)
	ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
	// A single element means a uniform FFN size for every layer.
	if len(ffnLength) == 1 {
		ffnLengthDefault = ffnLength[0]
	}
	nLayers := int(kv.BlockCount())
	if len(ffnLength) > nLayers {
		slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
	}
	out := make([]uint64, nLayers)
	for i := range nLayers {
		if i >= len(ffnLength) {
			out[i] = uint64(ffnLengthDefault)
		} else {
			out[i] = uint64(ffnLength[i])
		}
	}
	return out
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
"llama4",
"mistral3",
"mllama",
"nemotron_h", "nemotron_h_moe",
"nomic-bert",
"olmo3",
"qwen25vl",
@@ -273,6 +295,7 @@ func (kv KV) OllamaEngineRequired() bool {
"glm4moelite",
"glmocr",
"lfm2",
"lfm2moe",
}, kv.Architecture())
}
@@ -864,7 +887,9 @@ func (f GGML) FlashAttention() bool {
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"lfm2moe",
"mistral3",
"nemotron_h", "nemotron_h_moe",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",

752
kvcache/recurrent.go Normal file
View File

@@ -0,0 +1,752 @@
package kvcache
import (
"errors"
"fmt"
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Default checkpointing parameters for the recurrent cache (consumed by
// the checkpoint fields on Recurrent). NOTE(review): exact semantics —
// snapshot count, minimum position, spacing interval — are suggested by
// the names; confirm against the checkpoint logic below.
const (
	DefaultCheckpointCount    = 32
	DefaultCheckpointMinPos   = int32(16)
	DefaultCheckpointInterval = int32(1280)
)
// ErrInvalidRecurrentShape reports a recurrent state shape that does not
// match the cache configuration.
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
// RecurrentConfig configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
	// Shift adjusts cached keys for a layer when positions move (passed
	// through to the underlying causal cache).
	Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
	// ConvDim and ConvChannels define the per-layer, per-sequence conv
	// state shape [ConvDim, ConvChannels].
	ConvDim      int
	ConvChannels int
	// RecurrentStateSize is the flattened per-layer, per-sequence
	// recurrent state size.
	RecurrentStateSize int
	// CheckpointLogPrefix labels checkpoint-related log output.
	CheckpointLogPrefix string
}
var (
_ Cache = (*Recurrent)(nil)
_ CheckpointCache = (*Recurrent)(nil)
)
// Recurrent stores:
// - a standard causal KV cache
// - per-sequence conv state for recurrent operators
// - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
kv *Causal // wrapped causal KV cache for attention layers
backend ml.Backend
dtype ml.DType
maxSequences int // also the number of recurrent state slots
// Conv state dimensions
convDim int
convChannels int
// Recurrent state dimensions
recurrentStateSize int
logPrefix string // tag used in checkpoint log messages
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int // references per slot; >1 means shared, copy before write
freeSlots []int // LIFO free list of slot indices
seqCounts map[int]int // scratch: tokens per sequence in the current batch
slotScratch [1]int32 // scratch backing for single-slot index tensors
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer recurrent state buffers (allocated lazily)
recurrentCtxs map[int]ml.Context
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int // context size for checkpoint tensor allocation
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore // restores staged by PrepareRestore, applied in Remove
curCheckpointPos []int32 // per curSeqs entry: position to checkpoint this pass, or -1
curCheckpointSlots map[int]int // slot -> ring index recorded this pass
reserveCheckpoints bool // true during graph reservation (worst-case allocation)
checkpointConvCtxs map[int]ml.Context
checkpointRecurCtxs map[int]ml.Context
checkpointReserved map[int]struct{} // keys from checkpointReserveKey already reserved
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
// NewRecurrentCache creates a hybrid cache combining a causal KV cache with
// per-sequence recurrent state, configured with default checkpoint settings.
// Init must be called before use.
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
return &Recurrent{
kv: NewCausalCache(config.Shift),
convDim: config.ConvDim,
convChannels: config.ConvChannels,
recurrentStateSize: config.RecurrentStateSize,
logPrefix: config.CheckpointLogPrefix,
slotForSeq: make(map[int]int),
seqCounts: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
recurrentCtxs: make(map[int]ml.Context),
recurrentStates: make(map[int]ml.Tensor),
checkpointCount: DefaultCheckpointCount,
checkpointMinPos: DefaultCheckpointMinPos,
checkpointInterval: DefaultCheckpointInterval,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointRecurCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
// Init prepares the cache for use: resets checkpoint bookkeeping, sizes the
// slot allocator to maxSequences, and initializes the wrapped causal cache.
// NOTE(review): slotForSeq and seqCounts are not cleared here — if Init can
// be called more than once on the same instance, stale sequence->slot
// mappings would survive; confirm single-Init usage.
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
// Size checkpoint contexts for the worst case, with a small floor.
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
// Push slots in reverse so slot 0 is allocated first (LIFO free list).
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
// Close releases every context owned by this cache — conv state, recurrent
// state, and checkpoint contexts — and then closes the wrapped causal cache.
func (c *Recurrent) Close() {
	for _, ctxs := range []map[int]ml.Context{c.convCtxs, c.recurrentCtxs, c.checkpointConvCtxs, c.checkpointRecurCtxs} {
		for _, ctx := range ctxs {
			ctx.Close()
		}
	}
	c.kv.Close()
}
// SetConfig forwards the backend cache configuration to the causal cache.
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
// SetLayer selects the active layer on the causal cache.
func (c *Recurrent) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
// Get forwards to the wrapped causal cache for the active layer.
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
// Put stores key/value tensors in the wrapped causal cache.
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
// StartForward prepares both caches for a forward pass over batch. For the
// recurrent side it derives the unique sequences, ensures each has a state
// slot (zeroing newly allocated ones), and plans checkpoint captures.
// Multi-sequence batches must contain an equal number of tokens per
// sequence; otherwise ErrNotSupported is returned. When reserve is set,
// synthetic slots 0..nSeqs-1 are used for worst-case graph reservation.
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
nTokens := len(batch.Sequences)
if nTokens == 0 {
// Empty batch: clear per-pass state and return.
c.curSeqs = c.curSeqs[:0]
c.curSlots = c.curSlots[:0]
c.curSlotsInput = nil
c.curSeqTokens = 0
c.reserveCheckpoints = false
c.writableEnsured = false
c.writableError = nil
return nil
}
// Fast path for single-sequence batches (common during decode and prefill).
firstSeq := batch.Sequences[0]
singleSeq := true
for _, s := range batch.Sequences[1:] {
if s != firstSeq {
singleSeq = false
break
}
}
if singleSeq {
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
}
// Derive equal-length sequence layout for recurrent layers.
seqCounts := c.seqCounts
for s := range seqCounts {
delete(seqCounts, s)
}
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if seqCounts[s] == 0 {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return ErrNotSupported
}
}
c.curSeqTokens = want
if reserve {
// Reservation uses dense slots without touching the allocator.
c.curSlots = c.curSlots[:0]
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
}
c.finalizeStartForward(ctx, batch, true)
return nil
}
// Ensure slots exist for sequences in this batch.
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
c.finalizeStartForward(ctx, batch, false)
return nil
}
// startForwardSingleSeq is the StartForward fast path when the whole batch
// belongs to one sequence: one slot lookup/allocation, no per-token scan.
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
c.curSeqs = append(c.curSeqs[:0], seq)
c.curSeqTokens = seqTokens
if reserve {
// Reservation always uses slot 0.
c.curSlots = append(c.curSlots[:0], 0)
c.finalizeStartForward(ctx, batch, true)
return nil
}
slot, ok := c.slotForSeq[seq]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[seq] = slot
c.refCount[slot] = 1
slotList := [1]int{slot}
c.zeroSlots(ctx, slotList[:])
}
c.curSlots = append(c.curSlots[:0], slot)
c.finalizeStartForward(ctx, batch, false)
return nil
}
// finalizeStartForward completes per-pass setup: builds the slot index
// tensor, resets copy-on-write tracking, and plans checkpoint captures.
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
c.setCurSlotsInput(ctx)
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = reserve
c.planCheckpoints(batch)
}
// setCurSlotsInput rebuilds the tensor of current slot indices.
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
}
// slotsInput builds an int32 index tensor over slots for row gather/scatter
// operations. The single-slot case reuses slotScratch to avoid allocation.
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
switch len(slots) {
case 0:
return nil
case 1:
c.slotScratch[0] = int32(slots[0])
return ctx.Input().FromInts(c.slotScratch[:], 1)
default:
slotIndices := make([]int32, len(slots))
for i, v := range slots {
slotIndices[i] = int32(v)
}
return ctx.Input().FromInts(slotIndices, len(slotIndices))
}
}
// allocSlot pops a free slot index, or returns ErrKvCacheFull if all
// maxSequences slots are in use.
func (c *Recurrent) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
// freeSlot returns slot to the free list; out-of-range values are ignored.
func (c *Recurrent) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros recurrent state for the given slots across all cached
// layers. Only already-allocated layer buffers are touched; layers that are
// lazily allocated later start zeroed anyway.
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotsTensor := c.slotsInput(ctx, slots)
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
if len(c.recurrentStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
for _, buf := range c.recurrentStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
// Any current-batch sequence whose slot is shared (refCount > 1) gets a
// fresh slot with its state and checkpoints copied over, then the slot
// index tensor is rebuilt. Returns ErrKvCacheFull if no slot is available.
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
// Already exclusively owned; nothing to do.
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
c.setCurSlotsInput(ctx)
return nil
}
// copyRecurrentState duplicates the conv and recurrent state rows of
// srcSlot into dstSlot across every lazily-allocated layer buffer, casting
// to f32 where the stored rows are a different dtype.
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
	srcIdx := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
	dstIdx := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
	copyRow := func(buf ml.Tensor) {
		rows := buf.Rows(ctx, srcIdx)
		if rows.DType() != ml.DTypeF32 {
			rows = rows.Cast(ctx, ml.DTypeF32)
		}
		ctx.Forward(buf.SetRows(ctx, rows, dstIdx))
	}
	for _, buf := range c.convStates {
		copyRow(buf)
	}
	for _, buf := range c.recurrentStates {
		copyRow(buf)
	}
}
// CopyPrefix forks dstSeq from srcSeq: the causal cache copies the prefix,
// and the recurrent state is shared copy-on-write by aliasing dstSeq to
// srcSeq's slot and bumping its reference count.
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
	// Release dstSeq's existing slot reference before aliasing.
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		if c.validSlot(dstSlot) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				// Drop stale checkpoints before returning the slot to the
				// free list so a future owner cannot restore another
				// sequence's recurrent state (mirrors the free path in
				// Remove; previously the checkpoints were left behind).
				c.clearCheckpoints(dstSlot)
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}
	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		return
	}
	if c.validSlot(srcSlot) {
		// Share srcSeq's slot; EnsureWritable will split it on first write.
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}
// CanResume reports whether generation can resume for seq at pos: the
// causal cache must be resumable, and any nonzero position additionally
// requires a recurrent checkpoint strictly before pos.
func (c *Recurrent) CanResume(seq int, pos int32) bool {
	switch {
	case !c.kv.CanResume(seq, pos):
		return false
	case pos == 0:
		return true
	default:
		return c.hasCheckpoint(seq, pos)
	}
}
// Remove deletes positions [beginIndex, endIndex) for seq. Three cases:
//   - middle removal (beginIndex > 0, bounded endIndex): remove from the
//     causal cache and shift this slot's checkpoint positions;
//   - truncation from beginIndex > 0 to the end: requires a restore staged
//     by PrepareRestore at exactly beginIndex-1, which is applied after the
//     causal removal succeeds;
//   - full removal (beginIndex == 0): release the sequence's slot, clearing
//     its checkpoints when the last reference is dropped.
// Shared slots are detached (copy-on-write) before checkpoint mutation so
// siblings keep their state.
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
delete(c.pendingRestore, seq)
slot, ok := c.slotForSeq[seq]
if !ok || !c.validSlot(slot) {
return nil
}
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
if c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
slot = newSlot
}
c.shiftCheckpoints(slot, beginIndex, endIndex)
return nil
}
if beginIndex > 0 {
// Truncation: only supported when a matching restore has been staged.
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return ErrNotSupported
}
if !c.restoreComplete(restore) {
return ErrNotSupported
}
// Detach a shared slot first so the restore only affects this sequence.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Full removal: drop the sequence's slot reference.
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
// validSlot reports whether slot is a legal index into the slot tables.
func (c *Recurrent) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
// SlotsTensor returns the index tensor of slots for the current batch.
func (c *Recurrent) SlotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *Recurrent) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
// SeqTokens returns the number of tokens per sequence in the current batch.
func (c *Recurrent) SeqTokens() int {
return c.curSeqTokens
}
// NumSeqs returns the number of distinct sequences in the current batch.
func (c *Recurrent) NumSeqs() int {
return len(c.curSeqs)
}
// lazyStateBuffer returns the per-layer state buffer from states,
// allocating the buffer (and its dedicated context) on first use. Each
// buffer holds one rowSize-wide f32 row per slot.
func (c *Recurrent) lazyStateBuffer(layer int, ctxs map[int]ml.Context, states map[int]ml.Tensor, rowSize int) ml.Tensor {
	if buf, ok := states[layer]; ok {
		return buf
	}
	if _, ok := ctxs[layer]; !ok {
		ctxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}
	buf := ctxs[layer].Zeros(ml.DTypeF32, rowSize, c.maxSequences)
	states[layer] = buf
	return buf
}

// convBuffer returns the lazily-allocated conv state buffer for layer,
// shaped [convDim*convChannels, maxSequences].
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
	return c.lazyStateBuffer(layer, c.convCtxs, c.convStates, c.convDim*c.convChannels)
}

// recurrentBuffer returns the lazily-allocated recurrent state buffer for
// layer, shaped [recurrentStateSize, maxSequences].
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
	return c.lazyStateBuffer(layer, c.recurrentCtxs, c.recurrentStates, c.recurrentStateSize)
}
// ensureWritable runs copy-on-write at most once per forward pass and
// returns the cached error, if any.
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
c.ensureWritableOnce(ctx)
return c.writableError
}
// currentSlotRows reads the rows of buf for the current batch's slots,
// using a cheap view when the slots are contiguous, a gather otherwise.
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
}
return buf.Rows(ctx, c.SlotsTensor())
}
// writeCurrentSlotRows writes src into buf at the current batch's slots,
// using a view copy when the slots are contiguous, a scatter otherwise.
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
ctx.Forward(src.Copy(ctx, view))
return
}
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
}
// ensureWritableOnce performs the copy-on-write check a single time per
// pass, recording any allocation error in writableError.
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
buf := c.convBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
}
// UpdateConvState writes new conv state for current batch sequences.
// NOTE(review): this does not call ensureWritable itself — it appears to
// assume ConvState was called earlier in the same pass; confirm callers
// always read before writing.
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
// Snapshot the new state into any checkpoint planned for this pass.
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
// The product of dims must equal the configured recurrent state size;
// otherwise ErrInvalidRecurrentShape is returned.
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
if len(dims) == 0 {
return nil, ErrInvalidRecurrentShape
}
size := 1
for _, d := range dims {
if d <= 0 {
return nil, ErrInvalidRecurrentShape
}
size *= d
}
if size != c.recurrentStateSize {
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
}
buf := c.recurrentBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
shape := make([]int, 0, len(dims)+1)
shape = append(shape, dims...)
shape = append(shape, c.NumSeqs())
return cur.Reshape(ctx, shape...), nil
}
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
// It is a thin convenience wrapper over RecurrentState, which performs the
// copy-on-write and validates that dim0*dim1*dim2 matches the configured
// state size (returning an error wrapping ErrInvalidRecurrentShape
// otherwise, as before).
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
	return c.RecurrentState(ctx, layer, dim0, dim1, dim2)
}
// UpdateRecurrentState writes new recurrent state for current batch sequences.
// NOTE(review): like UpdateConvState, this does not call ensureWritable —
// it appears to assume RecurrentState was read first in the same pass.
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.recurrentBuffer(layer)
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
// Snapshot the new state into any checkpoint planned for this pass.
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *Recurrent) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *Recurrent) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,561 @@
package kvcache
import (
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
// checkpointEntry is one recorded snapshot: the token position it was taken
// at (-1 when empty) plus per-layer conv and recurrent state tensors.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
recurrent map[int]ml.Tensor
}
// slotCheckpointStore is a fixed-size ring buffer of checkpoints for one
// recurrent state slot.
type slotCheckpointStore struct {
entries []checkpointEntry
size int // number of live entries (pos >= 0)
next int // ring index to overwrite on the next record
lastPos int32 // most recently recorded position, -1 when empty
}
// checkpointRestore identifies a staged restore: which slot, which ring
// index, and the checkpoint's position.
type checkpointRestore struct {
slot int
idx int
pos int32
}
// newSlotCheckpointStore creates an empty ring with n entries.
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
// reset empties the ring; entry tensors are kept for reuse, only the
// position markers and counters are cleared.
func (s *slotCheckpointStore) reset() {
	for i := range s.entries {
		s.entries[i].pos = -1
	}
	s.size = 0
	s.next = 0
	s.lastPos = -1
}
// record claims the next ring entry for a checkpoint at pos and returns its
// index, overwriting the entry at next when the ring is full. Returns -1
// for a zero-capacity ring. Entry tensors are reused by the caller.
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
// bestIndex returns the ring index and position of the newest checkpoint
// strictly before targetPos, or ok=false when none qualifies.
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
	idx, best := -1, int32(-1)
	for i := range s.entries {
		if p := s.entries[i].pos; p >= 0 && p < targetPos && p > best {
			idx, best = i, p
		}
	}
	if idx < 0 {
		return -1, -1, false
	}
	return idx, best, true
}
// pruneAfter drops all checkpoints with positions strictly greater than
// pos, then rebuilds the ring counters: size is recounted, next points at
// the first empty entry (or the oldest live one when full), and lastPos is
// set to pos so the next checkpoint interval is measured from the prune
// point.
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
// shiftRange updates checkpoint positions after tokens in
// [beginIndex, endIndex) were removed from the sequence: checkpoints inside
// the removed range are dropped, those at or after endIndex are shifted
// left by the removed length, and the ring counters are rebuilt the same
// way pruneAfter does.
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
// Negative shift applied to surviving positions past the removed range.
offset := beginIndex - endIndex
size := 0
next := -1
minPos := int32(math.MaxInt32)
maxPos := int32(-1)
minIdx := 0
for i := range s.entries {
pos := s.entries[i].pos
if pos >= 0 {
if pos >= beginIndex && pos < endIndex {
s.entries[i].pos = -1
} else if pos >= endIndex {
s.entries[i].pos = pos + offset
}
}
pos = s.entries[i].pos
if pos >= 0 {
size++
if pos < minPos {
minPos = pos
minIdx = i
}
if pos > maxPos {
maxPos = pos
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = maxPos
}
// window summarizes the ring for diagnostics: the number of live
// checkpoints, their minimum and maximum positions (-1/-1 when empty), and
// the last recorded position.
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
	minPos, maxPos = int32(math.MaxInt32), int32(-1)
	for i := range s.entries {
		pos := s.entries[i].pos
		if pos < 0 {
			continue
		}
		size++
		minPos = min(minPos, pos)
		maxPos = max(maxPos, pos)
	}
	if size == 0 {
		minPos, maxPos = -1, -1
	}
	return size, minPos, maxPos, s.lastPos
}
// checkpointTag returns the prefix used in checkpoint log messages,
// defaulting to "kvcache.recurrent" when none was configured.
func (c *Recurrent) checkpointTag() string {
	if tag := c.logPrefix; tag != "" {
		return tag
	}
	return "kvcache.recurrent"
}
// planCheckpoints decides, per current-batch sequence, whether this forward
// pass should record a checkpoint: the sequence's max position must be at
// least checkpointMinPos and at least checkpointInterval past the slot's
// last checkpoint. Results go to curCheckpointPos (-1 = no checkpoint);
// curCheckpointSlots is cleared for the new pass.
func (c *Recurrent) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
// Highest position per sequence in this batch.
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
// checkpointStore returns slot's checkpoint ring, creating it on first use.
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
// checkpointIndexForSlot returns the ring index for this pass's checkpoint
// of slot at pos, recording it the first time and caching the index so the
// conv and recurrent capture paths share one entry per pass.
func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
// hasCheckpoint reports whether seq has a stored checkpoint strictly
// before pos that a restore could start from.
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
	if pos <= 0 {
		return false
	}
	slot, haveSlot := c.slotForSeq[seq]
	if !haveSlot {
		return false
	}
	if store, haveStore := c.checkpoints[slot]; haveStore {
		_, _, found := store.bestIndex(pos)
		return found
	}
	return false
}
// PrepareRestore stages a checkpoint restore for seq: it picks the newest
// checkpoint strictly before targetPos, records it in pendingRestore, and
// returns the position (checkpoint pos + 1) from which generation must be
// recomputed. Returns ok=false (logging a debug miss) when no usable
// checkpoint exists; the restore itself is applied later by Remove.
func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
// applyCheckpointRestore writes the checkpoint's per-layer conv and
// recurrent tensors back into the slot's state buffers, then prunes all
// checkpoints newer than the restored position. Returns ErrNotSupported if
// the staged entry is missing or incomplete.
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.recurrent {
buf := c.recurrentBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
// Only run the graph if anything was scheduled.
if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
// restoreComplete reports whether the staged restore refers to a live,
// fully-populated checkpoint entry.
func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
// restoreEntry resolves a staged restore to its checkpoint entry,
// validating the slot, ring index, liveness, and per-layer completeness.
func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
// entryComplete reports whether entry has a checkpoint tensor for every
// layer that currently has a conv or recurrent state buffer.
func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.recurrentStates {
if entry.recurrent == nil || entry.recurrent[layer] == nil {
return false
}
}
return true
}
// clearCheckpoints drops every checkpoint recorded for slot, if any.
func (c *Recurrent) clearCheckpoints(slot int) {
	store, ok := c.checkpoints[slot]
	if !ok {
		return
	}
	store.reset()
}

// shiftCheckpoints rewrites slot's checkpoint positions after tokens in
// [beginIndex, endIndex) were removed from its sequence.
func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
	store, ok := c.checkpoints[slot]
	if !ok {
		return
	}
	store.shiftRange(beginIndex, endIndex)
}
// copyCheckpoints deep-copies srcSlot's checkpoint ring into dstSlot's:
// ring counters and entry positions are mirrored and each per-layer tensor
// is copied into dstSlot's own (lazily allocated) checkpoint tensors.
// Used when a shared slot is split by copy-on-write.
func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.recurrent != nil {
if dstEntry.recurrent == nil {
dstEntry.recurrent = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.recurrent {
dst := c.ensureCheckpointRecurrent(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
// captureConvCheckpoint snapshots freshly written conv state for every
// sequence scheduled for a checkpoint this pass. During graph reservation
// it instead pre-allocates worst-case checkpoint tensors.
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointConv(layer)
		return
	}
	c.captureCheckpointRows(ctx, src, func(entry *checkpointEntry) ml.Tensor {
		return c.ensureCheckpointConv(layer, entry)
	})
}

// captureRecurrentCheckpoint snapshots freshly written recurrent state for
// every sequence scheduled for a checkpoint this pass. During graph
// reservation it instead pre-allocates worst-case checkpoint tensors.
func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointRecurrent(layer)
		return
	}
	c.captureCheckpointRows(ctx, src, func(entry *checkpointEntry) ml.Tensor {
		return c.ensureCheckpointRecurrent(layer, entry)
	})
}

// captureCheckpointRows is the shared body of the two capture functions:
// for each sequence with a planned checkpoint, it copies that sequence's
// column of src (shape [rowSize, nSeqs]) into the destination tensor
// selected by dstFor from the slot's checkpoint entry.
func (c *Recurrent) captureCheckpointRows(ctx ml.Context, src ml.Tensor, dstFor func(*checkpointEntry) ml.Tensor) {
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := dstFor(entry)
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}
// checkpointScratch allocates a [rowSize, 1] f32 tensor from the per-layer
// checkpoint context in ctxs, creating that context on first use. Shared
// by the conv and recurrent ensure functions below.
func (c *Recurrent) checkpointScratch(layer int, ctxs map[int]ml.Context, rowSize int) ml.Tensor {
	ctx, ok := ctxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		ctxs[layer] = ctx
	}
	return ctx.Zeros(ml.DTypeF32, rowSize, 1)
}

// ensureCheckpointConv returns entry's conv checkpoint tensor for layer,
// allocating it on first use.
func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.conv == nil {
		entry.conv = make(map[int]ml.Tensor)
	}
	if t, ok := entry.conv[layer]; ok {
		return t
	}
	t := c.checkpointScratch(layer, c.checkpointConvCtxs, c.convDim*c.convChannels)
	entry.conv[layer] = t
	return t
}

// ensureCheckpointRecurrent returns entry's recurrent checkpoint tensor
// for layer, allocating it on first use.
func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.recurrent == nil {
		entry.recurrent = make(map[int]ml.Tensor)
	}
	if t, ok := entry.recurrent[layer]; ok {
		return t
	}
	t := c.checkpointScratch(layer, c.checkpointRecurCtxs, c.recurrentStateSize)
	entry.recurrent[layer] = t
	return t
}
// reserveCheckpointConv pre-allocates conv checkpoint tensors for layer in
// every slot's ring so graph reservation accounts for worst-case memory.
func (c *Recurrent) reserveCheckpointConv(layer int) {
	c.reserveCheckpointStates(checkpointReserveKey(layer, 0), func(entry *checkpointEntry) {
		c.ensureCheckpointConv(layer, entry)
	})
}

// reserveCheckpointRecurrent pre-allocates recurrent checkpoint tensors for
// layer in every slot's ring so graph reservation accounts for worst-case
// memory.
func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
	c.reserveCheckpointStates(checkpointReserveKey(layer, 1), func(entry *checkpointEntry) {
		c.ensureCheckpointRecurrent(layer, entry)
	})
}

// reserveCheckpointStates runs ensure over every checkpoint entry of every
// slot, at most once per key (deduplicated via checkpointReserved).
func (c *Recurrent) reserveCheckpointStates(key int, ensure func(*checkpointEntry)) {
	if _, ok := c.checkpointReserved[key]; ok {
		return
	}
	for slot := range c.maxSequences {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			ensure(&store.entries[i])
		}
	}
	c.checkpointReserved[key] = struct{}{}
}

// checkpointReserveKey maps (layer, kind) to a unique reservation key;
// kind 0 is conv state, kind 1 is recurrent state.
func checkpointReserveKey(layer int, kind int) int {
	return layer*2 + kind
}

View File

@@ -0,0 +1,288 @@
package kvcache
import (
"errors"
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
)
// newTestCache builds a minimal Recurrent cache for unit tests.
func newTestCache() *Recurrent {
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
}
// TestSlotCheckpointStoreBestIndex verifies that bestIndex returns the
// newest checkpoint strictly before the target and respects ring overwrite.
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
// TestCachePrepareRestore verifies that PrepareRestore stages the newest
// checkpoint before the target and returns checkpoint pos + 1.
func TestCachePrepareRestore(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
// TestSlotCheckpointStorePruneAfter verifies checkpoints beyond the prune
// position are dropped while earlier ones remain usable.
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
// TestCacheRestoreRejectsIncompleteCheckpoint verifies Remove refuses a
// staged restore whose entry is missing a required per-layer tensor.
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate layer 0 requires both conv and recurrent checkpoints.
cache.convStates[0] = nil
cache.recurrentStates[0] = nil
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
entry.conv = map[int]ml.Tensor{0: nil}
// entry.recurrent intentionally missing
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
// TestCacheRestoreAcceptsCompleteCheckpoint verifies restoreComplete passes
// when no layer requires checkpoint tensors.
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
// TestCacheRecurrentStateShapeValidation verifies that a dims product not
// matching the configured state size is rejected.
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
cache := newTestCache()
_, err := cache.RecurrentState(nil, 0, 3)
if !errors.Is(err, ErrInvalidRecurrentShape) {
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
}
}
// TestSlotCheckpointStoreShiftRange verifies shiftRange semantics: entries in
// [2, 6) are dropped, entries at or past 6 are shifted left by the removed
// width (4), entries before 2 are untouched, and lastPos tracks the new max.
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
	store := newSlotCheckpointStore(5)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	store.shiftRange(2, 6)

	got := make([]int32, 0, len(store.entries))
	for _, e := range store.entries {
		if e.pos >= 0 {
			got = append(got, e.pos)
		}
	}
	slices.Sort(got)

	if want := []int32{1, 3, 6}; !slices.Equal(got, want) {
		t.Fatalf("unexpected shifted positions: got=%v want=%v", got, want)
	}
	if store.lastPos != 6 {
		t.Fatalf("expected lastPos 6, got %d", store.lastPos)
	}
}
// TestCacheRemoveMiddleShiftsCheckpoints verifies that removing a middle
// range [2, 6) from a sequence succeeds, clears any pending restore for that
// sequence, and shifts the slot's checkpoint positions just like shiftRange:
// positions in the removed range are dropped and later ones move left.
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
	c := newTestCache()
	c.slotForSeq[1] = 0
	c.refCount = []int{1}
	c.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}

	store := c.checkpointStore(0)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	if err := c.Remove(1, 2, 6); err != nil {
		t.Fatalf("expected middle remove to succeed, got %v", err)
	}
	if _, pending := c.pendingRestore[1]; pending {
		t.Fatalf("expected pending restore to be cleared after middle remove")
	}

	got := make([]int32, 0, len(store.entries))
	for _, e := range store.entries {
		if e.pos >= 0 {
			got = append(got, e.pos)
		}
	}
	slices.Sort(got)

	if want := []int32{1, 3, 6}; !slices.Equal(got, want) {
		t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", got, want)
	}
}
// TestSlotCheckpointStoreRingBufferWrapAround verifies that when the ring
// buffer is full and record wraps back to the oldest entry, the reused
// entry's conv/recurrent maps are kept (so their tensors can be recycled)
// while its position is overwritten with the new value.
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	// Attach state maps to the oldest entry before it is recycled.
	store.entries[0].conv = map[int]ml.Tensor{0: nil}
	store.entries[0].recurrent = map[int]ml.Tensor{0: nil}

	// Fourth record wraps around and reuses entry 0.
	store.record(40)

	reused := &store.entries[0]
	if reused.conv == nil {
		t.Fatalf("expected conv map to be preserved on reuse")
	}
	if reused.recurrent == nil {
		t.Fatalf("expected recurrent map to be preserved on reuse")
	}
	if reused.pos != 40 {
		t.Fatalf("expected entry 0 pos to be 40, got %d", reused.pos)
	}
}
// TestSlotCheckpointStoreFullCapacity verifies that a store filled exactly to
// capacity assigns sequential indices, reports the correct size, and that
// bestIndex still returns the nearest checkpoint at or below each target.
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
	store := newSlotCheckpointStore(2)

	first := store.record(10)
	second := store.record(20)
	if first != 0 || second != 1 {
		t.Fatalf("expected indices 0, 1, got %d, %d", first, second)
	}
	if store.size != 2 {
		t.Fatalf("expected size 2, got %d", store.size)
	}

	if _, pos, ok := store.bestIndex(15); !ok || pos != 10 {
		t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos, ok)
	}
	if _, pos, ok := store.bestIndex(25); !ok || pos != 20 {
		t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos, ok)
	}
}
// TestSlotCheckpointStoreEmptyBuffer verifies the degenerate zero-capacity
// store: record reports failure with -1 and bestIndex never finds a match.
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
	store := newSlotCheckpointStore(0)

	if idx := store.record(10); idx != -1 {
		t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
	}
	if _, _, ok := store.bestIndex(15); ok {
		t.Fatalf("expected no checkpoint for empty buffer")
	}
}
// TestSlotCheckpointStorePruneAfterAll verifies that pruning at a position
// before every recorded checkpoint empties the store: size drops to zero,
// lastPos resets to -1, and bestIndex no longer matches anything.
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	store.pruneAfter(5)

	if store.size != 0 {
		t.Fatalf("expected size 0 after pruning all, got %d", store.size)
	}
	if store.lastPos != -1 {
		t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
	}
	if _, _, ok := store.bestIndex(100); ok {
		t.Fatalf("expected no checkpoint after pruning all")
	}
}

View File

@@ -0,0 +1,37 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
specialization
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -163,6 +163,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
(ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {

View File

@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {

View File

@@ -1,410 +1,44 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
// HybridCache adapts the shared recurrent cache for LFM2:
// - KV attention cache is handled by the embedded causal cache
// - shortconv recurrent state uses conv slots [dConv, hiddenSize]
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
// This reuses shared checkpoint/restore logic for prefix mismatch recovery.
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
*kvcache.Recurrent
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: shift,
ConvDim: dConv,
ConvChannels: hiddenSize,
RecurrentStateSize: 1, // LFM2 uses only conv state; keep a minimal recurrent buffer size.
CheckpointLogPrefix: "lfm2",
})
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
// without modifying permanent state (slotForSeq, refCount)
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int // track newly allocated slots that need zeroing
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero conv state for newly allocated slots to clear stale data from previous sequences
if len(newSlots) > 0 {
c.zeroConvSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 || len(c.convStates) == 0 {
return
}
// Use input context for creating tensors
inputCtx := ctx.Input()
// Create slot indices tensor
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Create zero tensor for the slots (SetRows requires F32 source)
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
// Zero each layer's conv state for these slots
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
return &HybridCache{Recurrent: base}
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
return c.SlotsTensor()
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
return c.NumSeqs()
}

View File

@@ -4,441 +4,39 @@ import (
"testing"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
func TestHybridCache_New(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
if cache == nil {
t.Fatal("expected cache to be created")
}
// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
return &HybridCache{
hiddenSize: 256,
dConv: 3,
maxSequences: maxSequences,
refCount: make([]int, maxSequences),
freeSlots: initFreeSlots(maxSequences),
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
if cache.Recurrent == nil {
t.Fatal("expected embedded recurrent cache to be created")
}
}
func initFreeSlots(n int) []int {
slots := make([]int, 0, n)
for i := n - 1; i >= 0; i-- {
slots = append(slots, i)
}
return slots
}
func TestHybridCache_ImplementsCheckpointCache(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
func TestHybridCache_SlotAllocation(t *testing.T) {
cache := createSlotOnlyCache(4)
// Verify initial state
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
}
// Allocate all slots
for range 4 {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.refCount[slot] = 1
}
// Should be full now
if len(cache.freeSlots) != 0 {
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
}
// Trying to allocate another should fail
_, err := cache.allocSlot()
if err != kvcache.ErrKvCacheFull {
t.Errorf("expected ErrKvCacheFull, got %v", err)
if _, ok := any(cache).(kvcache.CheckpointCache); !ok {
t.Fatal("expected HybridCache to implement CheckpointCache")
}
}
func TestHybridCache_SlotReuse(t *testing.T) {
cache := createSlotOnlyCache(4)
func TestHybridCache_DefaultBatchState(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
// Allocate a slot
slot1, _ := cache.allocSlot()
cache.refCount[slot1] = 1
// Free it
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
// Allocate again - should get the same slot back (LIFO)
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
}
}
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Simulate sharing slot with seq 2 (copy-on-write style)
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Should share the same slot
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
if got := cache.numSeqs(); got != 0 {
t.Fatalf("expected 0 sequences before StartForward, got %d", got)
}
// Ref count should be 2
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share with seq 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Unshare seq 2
cache.refCount[slot1]--
delete(cache.slotForSeq, 2)
// Ref count should be back to 1
if cache.refCount[slot1] != 1 {
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
if got := cache.seqTokens(); got != 0 {
t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got)
}
// Seq 2 should no longer have a slot
if _, ok := cache.slotForSeq[2]; ok {
t.Error("seq 2 should not have a slot after unshare")
}
}
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
cache := createSlotOnlyCache(4)
initialFreeSlots := len(cache.freeSlots)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot when refCount drops to 0
cache.refCount[slot1]--
if cache.refCount[slot1] <= 0 {
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
}
delete(cache.slotForSeq, 1)
// Slot should be freed
if len(cache.freeSlots) != initialFreeSlots {
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
}
// Ref count should be 0
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotOverwrite(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slots for seq 1 and seq 2
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
slot2, _ := cache.allocSlot()
cache.slotForSeq[2] = slot2
cache.refCount[slot2] = 1
initialFreeSlots := len(cache.freeSlots)
// Simulate overwriting seq 2's slot with slot1 (sharing)
// First free the old slot
cache.refCount[slot2]--
if cache.refCount[slot2] <= 0 {
cache.refCount[slot2] = 0
cache.freeSlot(slot2)
}
// Then share slot1
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Seq 2 should now share slot1
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Old slot2 should be freed
if len(cache.freeSlots) != initialFreeSlots+1 {
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
}
}
func TestHybridCache_BoundsChecking(t *testing.T) {
cache := createSlotOnlyCache(4)
// Test freeing invalid slot (should not panic)
cache.freeSlot(-1)
cache.freeSlot(100) // out of bounds
// freeSlot does bounds checking, so invalid slots should be ignored
if len(cache.freeSlots) != 4 {
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
}
}
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
cache := createSlotOnlyCache(8)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Fork to seq 2, 3, 4 (all share slot1)
for _, seq := range []int{2, 3, 4} {
cache.slotForSeq[seq] = slot1
cache.refCount[slot1]++
}
// Ref count should be 4
if cache.refCount[slot1] != 4 {
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
}
// Remove seq 2, 3
for _, seq := range []int{2, 3} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
// Slot should still be allocated (not in free list)
found := false
for _, s := range cache.freeSlots {
if s == slot1 {
found = true
break
}
}
if found {
t.Error("slot1 should not be in free list yet")
}
// Remove remaining sequences
for _, seq := range []int{1, 4} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_ChainedSharing(t *testing.T) {
cache := createSlotOnlyCache(8)
// Create seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share 1 -> 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Share 2 -> 3 (should still share slot1)
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
cache.refCount[slot1]++
// All should share slot1
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
t.Error("all sequences should share slot1")
}
if cache.refCount[slot1] != 3 {
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_CacheParameters(t *testing.T) {
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
if cache.hiddenSize != 512 {
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
}
if cache.dConv != 5 {
t.Errorf("expected dConv 5, got %d", cache.dConv)
}
}
func TestHybridCache_NumSeqs(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially no sequences
if cache.numSeqs() != 0 {
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
}
// Manually set up current batch state
cache.curSeqs = []int{1, 2, 3}
if cache.numSeqs() != 3 {
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
}
}
func TestHybridCache_SeqTokens(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially 0
if cache.seqTokens() != 0 {
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
}
// Manually set up current batch state
cache.curSeqTokens = 16
if cache.seqTokens() != 16 {
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
}
}
// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
cache := createSlotOnlyCache(4)
cache.curSeqs = []int{1, 2, 3}
seqs := cache.Seqs()
// Modify returned slice
seqs[0] = 999
// Original should be unchanged
if cache.curSeqs[0] != 1 {
t.Error("Seqs should return a clone, not the original slice")
}
}
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially not supported (no batch set up)
if cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be false initially")
}
// Set up a valid batch
cache.curSeqTokens = 1
cache.curSeqs = []int{1}
if !cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be true with valid batch")
}
}
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
cache := createSlotOnlyCache(4)
// zeroConvSlots should handle empty slots without panicking
cache.zeroConvSlots(nil, nil)
cache.zeroConvSlots(nil, []int{})
// zeroConvSlots should handle empty convStates without panicking
cache.zeroConvSlots(nil, []int{0, 1, 2})
}
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot (simulating sequence removal)
cache.refCount[slot1]--
cache.freeSlot(slot1)
delete(cache.slotForSeq, 1)
// Verify slot is in free list
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
}
// Allocate for new seq 2 - should get recycled slot
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
}
// This recycled slot would need zeroing in the real implementation
// The actual zeroing is tested via integration tests since it requires ML context
}
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
cache := createSlotOnlyCache(4)
// Simulate the slot allocation flow from StartForward
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
newSlots := []int{}
// Seq 1 doesn't have a slot - allocate and track
seq := 1
if _, ok := cache.slotForSeq[seq]; !ok {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
// Verify newSlots contains the allocated slot
if len(newSlots) != 1 {
t.Errorf("expected 1 new slot, got %d", len(newSlots))
}
// Seq 1 already has a slot - should NOT be tracked as new
newSlots2 := []int{}
if _, ok := cache.slotForSeq[seq]; !ok {
slot, _ := cache.allocSlot()
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots2 = append(newSlots2, slot)
}
// Verify no new slots for existing sequence
if len(newSlots2) != 0 {
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
t.Fatal("expected unsupported batch layout before StartForward")
}
}

View File

@@ -1,7 +1,11 @@
package lfm2
import (
"bytes"
"cmp"
"errors"
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
@@ -25,8 +29,20 @@ type Options struct {
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
// MoE config
numExperts int
numExpertsUsed int
normTopKProb bool
expertWeightsScale float32
expertGatingFunc uint32
}
const (
expertGatingFuncSoftmax = uint32(0)
expertGatingFuncSigmoid = uint32(2)
)
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
@@ -67,18 +83,138 @@ type Model struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
VisionModel *VisionModel `gguf:"v"`
VisionProjector *VisionProjector `gguf:"mm"`
ImageProcessor ImageProcessor
imageTokenID int32
imageStartToken int32
imageEndToken int32
imageThumbnailID int32
imageRowColIDs map[imageGridPos]int32
useSpecialTokens bool
projectorOptions VisionProjectorOptions
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
return nil, model.ErrUnsupportedModel
// Compile-time assertion that Model implements multimodal processing.
var _ model.MultimodalProcessor = (*Model)(nil)

// imageGridPos identifies a tile by its 1-based (row, col) grid position.
type imageGridPos struct {
	row int
	col int
}

// visionEmbeddingLayout describes how an image was tiled: the grid dimensions
// plus whether a whole-image thumbnail tile was appended after the grid.
type visionEmbeddingLayout struct {
	rows         int
	cols         int
	hasThumbnail bool
}

// visionChunkData is the per-chunk payload carried on multimodal inputs:
// the number of embedding tokens the chunk occupies, its tile position,
// whether it is the thumbnail, and (first chunk only) the overall layout.
type visionChunkData struct {
	tokens    int
	row       int
	col       int
	thumbnail bool
	layout    *visionEmbeddingLayout
}
// Validate checks that every tensor the text and (optional) vision towers
// require was actually loaded, returning an error naming the first missing
// tensor so a broken GGUF fails fast instead of panicking mid-forward.
func (m *Model) Validate() error {
	if m.TokenEmbedding == nil {
		return errors.New("lfm2: missing token_embd tensor")
	}
	if m.OutputNorm == nil {
		return errors.New("lfm2: missing output_norm tensor")
	}
	if m.Output == nil {
		return errors.New("lfm2: missing output tensor")
	}
	for i, layer := range m.Layers {
		if layer.AttentionNorm == nil {
			return fmt.Errorf("lfm2: missing blk.%d.attn_norm tensor", i)
		}
		if layer.MLPNorm == nil {
			return fmt.Errorf("lfm2: missing blk.%d.ffn_norm tensor", i)
		}
		switch ff := layer.MLP.(type) {
		case nil:
			return fmt.Errorf("lfm2: missing blk.%d feed-forward tensors", i)
		case *denseMLP:
			// ff == nil guards against a typed-nil pointer stored in the
			// interface, which `case nil` above does not match and which
			// would otherwise panic on field access. This mirrors the
			// op == nil checks in the Operator switch below.
			if ff == nil || ff.Up == nil || ff.Down == nil || ff.Gate == nil {
				return fmt.Errorf("lfm2: missing blk.%d dense feed-forward tensors", i)
			}
		case *sparseMLP:
			if ff == nil || ff.Router == nil || ff.Gate == nil || ff.Up == nil || ff.Down == nil {
				return fmt.Errorf("lfm2: missing blk.%d sparse feed-forward tensors", i)
			}
		default:
			return fmt.Errorf("lfm2: unsupported feed-forward type at blk.%d", i)
		}
		switch op := layer.Operator.(type) {
		case *Attention:
			if op == nil || op.Query == nil || op.Key == nil || op.Value == nil || op.Output == nil || op.QueryNorm == nil || op.KeyNorm == nil {
				return fmt.Errorf("lfm2: missing blk.%d attention tensors", i)
			}
		case *ShortConv:
			if op == nil || op.Conv == nil || op.Conv.Weight == nil || op.InProj == nil || op.OutProj == nil {
				return fmt.Errorf("lfm2: missing blk.%d shortconv tensors", i)
			}
		default:
			return fmt.Errorf("lfm2: unsupported operator at blk.%d", i)
		}
	}
	// The vision tower is optional; validate it only when present.
	if m.VisionModel != nil {
		if m.VisionModel.PatchEmbedding == nil {
			return errors.New("lfm2: missing vision patch embedding tensors")
		}
		if m.VisionModel.PositionEmbedding == nil {
			return errors.New("lfm2: missing vision position embedding tensors")
		}
		if m.VisionModel.PostLayerNorm == nil {
			return errors.New("lfm2: missing vision post layer norm tensors")
		}
		if len(m.VisionModel.Layers) == 0 {
			return errors.New("lfm2: missing vision encoder layers")
		}
		for i, layer := range m.VisionModel.Layers {
			if layer.LayerNorm1 == nil || layer.LayerNorm2 == nil || layer.SelfAttention == nil || layer.MLP == nil {
				return fmt.Errorf("lfm2: missing vision layer tensors at v.blk.%d", i)
			}
			if layer.SelfAttention.Query == nil || layer.SelfAttention.Key == nil || layer.SelfAttention.Value == nil || layer.SelfAttention.Output == nil {
				return fmt.Errorf("lfm2: missing vision attention tensors at v.blk.%d", i)
			}
			if layer.MLP.Up == nil || layer.MLP.Down == nil {
				return fmt.Errorf("lfm2: missing vision feed-forward tensors at v.blk.%d", i)
			}
		}
		if m.VisionProjector == nil || m.VisionProjector.Linear1 == nil || m.VisionProjector.Linear2 == nil {
			return errors.New("lfm2: missing multimodal projector tensors")
		}
	}
	return nil
}
func New(c fs.Config) (model.Model, error) {
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
numExperts := int(c.Uint("expert_count"))
isMoE := numExperts > 0
numExpertsUsed := int(c.Uint("expert_used_count"))
if isMoE {
if numExperts <= 0 {
return nil, fmt.Errorf("lfm2: invalid expert_count=%d", numExperts)
}
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("lfm2: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -105,8 +241,16 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
VisionProjector: &VisionProjector{},
imageRowColIDs: make(map[imageGridPos]int32),
projectorOptions: VisionProjectorOptions{
scaleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
useLayerNorm: c.Bool("vision.projector.use_layernorm", false),
},
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
@@ -116,9 +260,66 @@ func New(c fs.Config) (model.Model, error) {
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
normTopKProb: c.Bool("norm_top_k_prob", true),
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
expertGatingFunc: c.Uint("expert_gating_func", expertGatingFuncSoftmax),
},
}
lookupTokenID := func(token string) int32 {
for i, t := range vocabulary.Values {
if t == token {
return int32(i)
}
}
return 0
}
resolveTokenID := func(explicitKey, token string, fallback uint32) int32 {
if explicitKey != "" {
if id := c.Uint(explicitKey); id != 0 {
return int32(id)
}
}
if tokenID := lookupTokenID(token); tokenID != 0 {
return tokenID
}
return int32(fallback)
}
m.imageTokenID = resolveTokenID("vision.image_token_id", "<image>", 396)
m.imageStartToken = resolveTokenID("vision.image_start_token_id", "<|image_start|>", 0)
m.imageEndToken = resolveTokenID("vision.image_end_token_id", "<|image_end|>", 0)
m.imageThumbnailID = resolveTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>", 0)
m.useSpecialTokens = c.Bool("vision.use_image_special_tokens", true)
maxGridTokens := int(c.Uint("vision.max_tiles", 10))
if maxGridTokens <= 0 {
maxGridTokens = 10
}
for row := 1; row <= maxGridTokens; row++ {
for col := 1; col <= maxGridTokens; col++ {
token := fmt.Sprintf("<|img_row_%d_col_%d|>", row, col)
if tokenID := lookupTokenID(token); tokenID > 0 {
m.imageRowColIDs[imageGridPos{row: row, col: col}] = tokenID
}
}
}
if !m.useSpecialTokens {
m.imageStartToken = 0
m.imageEndToken = 0
m.imageThumbnailID = 0
m.imageRowColIDs = map[imageGridPos]int32{}
}
if c.Uint("vision.block_count") == 0 {
m.VisionModel = nil
m.VisionProjector = nil
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
@@ -133,6 +334,14 @@ func New(c fs.Config) (model.Model, error) {
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count"))
if leadingDenseBlockCount < 0 {
leadingDenseBlockCount = 0
}
if leadingDenseBlockCount > len(m.Layers) {
leadingDenseBlockCount = len(m.Layers)
}
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
@@ -142,6 +351,12 @@ func New(c fs.Config) (model.Model, error) {
} else {
m.Layers[i].Operator = &Attention{}
}
if isMoE && i >= leadingDenseBlockCount {
m.Layers[i].MLP = &sparseMLP{}
} else {
m.Layers[i].MLP = &denseMLP{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
@@ -188,22 +403,77 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
// FeedForward is the per-layer MLP contract implemented by both the dense
// and the mixture-of-experts (sparse) variants.
type FeedForward interface {
	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

// denseMLP is the standard gated feed-forward block.
type denseMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
// Forward applies the gated feed-forward block: the gate projection is
// SILU-combined with the up projection, then projected back down.
func (mlp *denseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	gated := mlp.Gate.Forward(ctx, hiddenState)
	activated := gated.SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, activated)
}
// sparseMLP is the mixture-of-experts feed-forward block: a router selects
// the top-k experts per token and their batched projections are combined.
type sparseMLP struct {
	Router *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate   *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up     *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
	// Bias, when present, is added to the routing probabilities for expert
	// selection only (the mixing weights stay unbiased; see Forward).
	Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
// Forward routes each token through the top-k experts chosen by the router
// and combines the expert outputs weighted by their (optionally normalized,
// optionally scaled) routing probabilities.
//
// Fix: the softmax op was previously built unconditionally and then discarded
// when sigmoid gating is configured, adding a dead node to the graph; now only
// the gating op that is actually used is constructed.
func (mlp *sparseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	// hiddenState: [hidden, tokens]
	routerLogits := mlp.Router.Forward(ctx, hiddenState)

	var probs ml.Tensor
	if opts.expertGatingFunc == expertGatingFuncSigmoid {
		probs = routerLogits.Sigmoid(ctx)
	} else {
		probs = routerLogits.Softmax(ctx)
	}

	// The bias (when present) influences which experts are selected, but the
	// mixing weights below come from the unbiased probabilities.
	selectionProbs := probs
	if mlp.Bias != nil {
		selectionProbs = selectionProbs.Add(ctx, mlp.Bias)
	}
	selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)

	routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
	if opts.normTopKProb {
		// Renormalize the selected weights to sum to 1 per token, clamping
		// the denominator away from zero.
		routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
		weightsSum := routingWeights.SumRows(ctx)
		weightsSum = weightsSum.Clamp(ctx, 1e-6, float32(math.Inf(1)))
		routingWeights = routingWeights.Div(ctx, weightsSum)
		routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
	}
	if opts.expertWeightsScale != 1 {
		routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
	}

	// Build routing-weights branch early to enable topk-MoE fusion.
	ctx.Forward(routingWeights)

	hiddenState3D := hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
	experts := mlp.Gate.Forward(ctx, hiddenState3D, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenState3D, selectedExperts))
	experts = mlp.Down.Forward(ctx, experts, selectedExperts)
	experts = experts.Mul(ctx, routingWeights)

	// Sum the weighted per-expert outputs via strided views over the expert axis.
	nextState := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextState = nextState.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}
	return nextState
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
MLP FeedForward
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
@@ -229,10 +499,233 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
// multimodalTokenCount reports how many embedding positions mm occupies.
// The attached tensor's second dimension wins when present; otherwise the
// count is read from the Data payload, defaulting to 0 for unknown shapes.
func multimodalTokenCount(mm input.Multimodal) int {
	if mm.Tensor != nil {
		return mm.Tensor.Dim(1)
	}
	switch v := mm.Data.(type) {
	case int:
		return v
	case int32:
		return int(v)
	case visionChunkData:
		return v.tokens
	case *visionChunkData:
		if v == nil {
			return 0
		}
		return v.tokens
	default:
		return 0
	}
}
// multimodalChunkInfo extracts the visionChunkData payload from mm,
// synthesizing a minimal one (token count only) when none is attached.
func multimodalChunkInfo(mm input.Multimodal) visionChunkData {
	if chunk, ok := mm.Data.(visionChunkData); ok {
		return chunk
	}
	if chunk, ok := mm.Data.(*visionChunkData); ok && chunk != nil {
		return *chunk
	}
	return visionChunkData{tokens: multimodalTokenCount(mm)}
}
// multimodalLayout returns the tile grid recorded on the first chunk,
// defaulting to a single 1x1 tile when no layout is attached.
func multimodalLayout(mm []input.Multimodal) visionEmbeddingLayout {
	fallback := visionEmbeddingLayout{rows: 1, cols: 1}
	if len(mm) == 0 {
		return fallback
	}
	if layout := multimodalChunkInfo(mm[0]).layout; layout != nil {
		return *layout
	}
	return fallback
}
// imageRowColToken looks up the special token marking tile (row, col).
// It returns 0 for out-of-range positions or when no such token exists
// (a missing map entry yields the zero value).
func (m *Model) imageRowColToken(row, col int) int32 {
	if row < 1 || col < 1 {
		return 0
	}
	return m.imageRowColIDs[imageGridPos{row: row, col: col}]
}
// appendImageChunk emits the placeholder-token run for one vision chunk:
// the first token carries the multimodal payload, its hash, and a SameBatch
// span covering the remaining placeholders; the rest are bare image tokens.
func (m *Model) appendImageChunk(result []*input.Input, chunk input.Multimodal, imageToken int32, hash uint64) ([]*input.Input, error) {
	n := multimodalTokenCount(chunk)
	if n <= 0 {
		return nil, errors.New("lfm2: multimodal input has no tokens")
	}
	result = append(result, &input.Input{
		Token:          imageToken,
		Multimodal:     []input.Multimodal{chunk},
		MultimodalHash: hash,
		SameBatch:      n - 1,
	})
	for i := 1; i < n; i++ {
		result = append(result, &input.Input{Token: imageToken})
	}
	return result, nil
}
// EncodeMultimodal decodes an image, tiles/resizes it via the image
// processor, and runs each tile through the vision encoder and projector,
// returning one Multimodal chunk per tile. The first chunk additionally
// carries the overall grid layout for PostTokenize.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
	if m.VisionModel == nil || m.VisionProjector == nil || len(m.VisionModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}
	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
	processedImages, layout, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, err
	}
	// Guard before dividing by patchSize below.
	if m.ImageProcessor.patchSize <= 0 {
		return nil, errors.New("lfm2: invalid vision patch size")
	}
	layoutInfo := &visionEmbeddingLayout{
		rows:         layout.rows,
		cols:         layout.cols,
		hasThumbnail: layout.hasThumbnail,
	}
	mm := make([]input.Multimodal, 0, len(processedImages))
	for i, processed := range processedImages {
		// Tile dimensions in patches, as seen by the encoder.
		patches := visionPatchGrid{
			Width:  processed.size.X / m.ImageProcessor.patchSize,
			Height: processed.size.Y / m.ImageProcessor.patchSize,
		}
		if patches.Width == 0 || patches.Height == 0 {
			return nil, errors.New("lfm2: invalid resized image dimensions")
		}
		pixelValues := ctx.Input().FromFloats(processed.data, processed.size.X, processed.size.Y, m.ImageProcessor.numChannels)
		visionOutputs := m.VisionModel.Forward(ctx, pixelValues, patches)
		projected := m.VisionProjector.Forward(ctx, visionOutputs, patches, m.projectorOptions)
		// Record the tile's grid position and token count for PostTokenize.
		chunk := visionChunkData{
			tokens:    projected.Dim(1),
			row:       processed.row,
			col:       processed.col,
			thumbnail: processed.thumbnail,
		}
		if i == 0 {
			// Layout travels on the first chunk only.
			chunk.layout = layoutInfo
		}
		mm = append(mm, input.Multimodal{
			Tensor: projected,
			Data:   chunk,
		})
	}
	return mm, nil
}
// PostTokenize expands each multimodal input into its token stream:
// optional <|image_start|>, then per chunk an optional tile marker
// (row/col or thumbnail token) followed by the chunk's run of image
// placeholder tokens, then optional <|image_end|>. Non-image inputs pass
// through unchanged.
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
	var result []*input.Input
	imageToken := m.imageTokenID
	if imageToken == 0 {
		// Fallback matches the default resolved in New for "<image>".
		imageToken = 396
	}
	// Special tokens are used when explicitly enabled or when any special
	// token/marker id is actually configured.
	useSpecialTokens := m.useSpecialTokens || m.imageStartToken > 0 || m.imageEndToken > 0 || m.imageThumbnailID > 0 || len(m.imageRowColIDs) > 0
	for _, inp := range inputs {
		if len(inp.Multimodal) == 0 {
			result = append(result, inp)
			continue
		}
		layout := multimodalLayout(inp.Multimodal)
		if layout.rows <= 0 {
			layout.rows = 1
		}
		if layout.cols <= 0 {
			layout.cols = 1
		}
		tiles := layout.rows * layout.cols
		multitile := tiles > 1
		if useSpecialTokens && m.imageStartToken > 0 {
			result = append(result, &input.Input{Token: m.imageStartToken})
		}
		for i, mm := range inp.Multimodal {
			chunk := multimodalChunkInfo(mm)
			if chunk.tokens <= 0 {
				chunk.tokens = multimodalTokenCount(mm)
			}
			// Infer the tile's row/col from its index (row-major) when the
			// encoder didn't record a position.
			if multitile && !chunk.thumbnail && chunk.row == 0 && chunk.col == 0 && i < tiles {
				chunk.row = i/layout.cols + 1
				chunk.col = i%layout.cols + 1
			}
			// The chunk immediately after the grid tiles is the thumbnail.
			if multitile && layout.hasThumbnail && i == tiles {
				chunk.thumbnail = true
			}
			// Tile markers are only emitted for multi-tile images.
			if useSpecialTokens && multitile {
				if chunk.thumbnail {
					if m.imageThumbnailID > 0 {
						result = append(result, &input.Input{Token: m.imageThumbnailID})
					}
				} else if marker := m.imageRowColToken(chunk.row, chunk.col); marker > 0 {
					result = append(result, &input.Input{Token: marker})
				}
			}
			var err error
			result, err = m.appendImageChunk(result, input.Multimodal{
				Tensor: mm.Tensor,
				Data:   chunk,
			}, imageToken, inp.MultimodalHash)
			if err != nil {
				return nil, err
			}
		}
		if useSpecialTokens && m.imageEndToken > 0 {
			result = append(result, &input.Input{Token: m.imageEndToken})
		}
	}
	return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if len(batch.Multimodal) > 0 {
// We splice vision embeddings into token embeddings in-place; duplicate to
// avoid aliasing the raw embedding output graph.
hiddenState = hiddenState.Duplicate(ctx)
}
for _, mm := range batch.Multimodal {
offset := mm.Index
for _, multimodal := range mm.Multimodal {
if multimodal.Tensor == nil {
continue
}
visionOutputs := multimodal.Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, offset*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
offset += visionOutputs.Dim(1)
}
}
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
@@ -251,4 +744,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func init() {
model.Register("lfm2", New)
model.Register("lfm2moe", New)
}

View File

@@ -0,0 +1,160 @@
package lfm2
import (
"testing"
"github.com/ollama/ollama/model/input"
)
// TestPostTokenizeWithSpecialImageTokens verifies the expansion of one
// 64-token image bracketed by start/end tokens:
// 1 text + <start> + 64 placeholders + <end> + 1 text = 68 tokens.
func TestPostTokenizeWithSpecialImageTokens(t *testing.T) {
	m := &Model{
		imageTokenID:     396,
		imageStartToken:  2,
		imageEndToken:    3,
		useSpecialTokens: true,
	}
	in := []*input.Input{
		{Token: 11},
		{Multimodal: []input.Multimodal{{Data: 64}}, MultimodalHash: 123},
		{Token: 12},
	}
	out, err := m.PostTokenize(in)
	if err != nil {
		t.Fatalf("PostTokenize returned error: %v", err)
	}
	if len(out) != 68 {
		t.Fatalf("expected 68 tokens, got %d", len(out))
	}
	if out[0].Token != 11 {
		t.Fatalf("out[0].Token = %d, want 11", out[0].Token)
	}
	if out[1].Token != 2 {
		t.Fatalf("out[1].Token = %d, want 2", out[1].Token)
	}
	// Only the first image token carries the payload, its hash, and a
	// SameBatch span covering the remaining 63 placeholders.
	firstImage := out[2]
	if firstImage.Token != 396 {
		t.Fatalf("out[2].Token = %d, want 396", firstImage.Token)
	}
	if len(firstImage.Multimodal) != 1 {
		t.Fatalf("expected multimodal payload on first image token")
	}
	if firstImage.MultimodalHash != 123 {
		t.Fatalf("out[2].MultimodalHash = %d, want 123", firstImage.MultimodalHash)
	}
	if firstImage.SameBatch != 63 {
		t.Fatalf("out[2].SameBatch = %d, want 63", firstImage.SameBatch)
	}
	// The trailing placeholders are bare image tokens.
	for i := 3; i < 66; i++ {
		if out[i].Token != 396 {
			t.Fatalf("out[%d].Token = %d, want 396", i, out[i].Token)
		}
		if len(out[i].Multimodal) != 0 {
			t.Fatalf("out[%d] should not carry multimodal payload", i)
		}
	}
	if out[66].Token != 3 {
		t.Fatalf("out[66].Token = %d, want 3", out[66].Token)
	}
	if out[67].Token != 12 {
		t.Fatalf("out[67].Token = %d, want 12", out[67].Token)
	}
}
// TestPostTokenizeWithoutSpecialImageTokens verifies that with special image
// tokens disabled, an image expands to exactly its placeholder run with the
// payload attached only to the first token.
func TestPostTokenizeWithoutSpecialImageTokens(t *testing.T) {
	m := &Model{
		imageTokenID:     777,
		useSpecialTokens: false,
	}

	out, err := m.PostTokenize([]*input.Input{
		{Multimodal: []input.Multimodal{{Data: 5}}, MultimodalHash: 9},
	})
	if err != nil {
		t.Fatalf("PostTokenize returned error: %v", err)
	}
	if len(out) != 5 {
		t.Fatalf("expected 5 tokens, got %d", len(out))
	}
	if out[0].Token != 777 || out[0].SameBatch != 4 || len(out[0].Multimodal) != 1 {
		t.Fatalf("unexpected first token: %+v", *out[0])
	}
	for j, inp := range out[1:] {
		i := j + 1
		if inp.Token != 777 {
			t.Fatalf("out[%d].Token = %d, want 777", i, inp.Token)
		}
		if len(inp.Multimodal) != 0 {
			t.Fatalf("out[%d] should not carry multimodal payload", i)
		}
	}
}
// TestPostTokenizeMultiTileLayoutTokens verifies the token stream for a
// 1x2 tiled image plus thumbnail: a start marker, a row/col (or thumbnail)
// marker before each placeholder run, and an end marker.
func TestPostTokenizeMultiTileLayoutTokens(t *testing.T) {
	m := &Model{
		imageTokenID:     396,
		imageStartToken:  498,
		imageEndToken:    499,
		imageThumbnailID: 497,
		imageRowColIDs: map[imageGridPos]int32{
			{row: 1, col: 1}: 397,
			{row: 1, col: 2}: 398,
		},
		useSpecialTokens: true,
	}
	// The grid layout is carried on the first chunk only.
	layout := &visionEmbeddingLayout{rows: 1, cols: 2, hasThumbnail: true}
	in := []*input.Input{{
		Multimodal: []input.Multimodal{
			{Data: visionChunkData{tokens: 3, row: 1, col: 1, layout: layout}},
			{Data: visionChunkData{tokens: 3, row: 1, col: 2}},
			{Data: visionChunkData{tokens: 2, thumbnail: true}},
		},
		MultimodalHash: 1,
	}}
	out, err := m.PostTokenize(in)
	if err != nil {
		t.Fatalf("PostTokenize returned error: %v", err)
	}
	got := make([]int32, len(out))
	for i := range out {
		got[i] = out[i].Token
	}
	want := []int32{
		498, // <|image_start|>
		397, // <|img_row_1_col_1|>
		396, 396, 396,
		398, // <|img_row_1_col_2|>
		396, 396, 396,
		497, // <|img_thumbnail|>
		396, 396,
		499, // <|image_end|>
	}
	if len(got) != len(want) {
		t.Fatalf("len(out) = %d, want %d", len(got), len(want))
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("out[%d].Token = %d, want %d", i, got[i], want[i])
		}
	}
	// Indices 2, 6, and 10 are the first placeholder of each chunk and must
	// carry the payload and the SameBatch span for the rest of the chunk.
	if len(out[2].Multimodal) != 1 || len(out[6].Multimodal) != 1 || len(out[10].Multimodal) != 1 {
		t.Fatalf("expected multimodal payload on first token of each chunk")
	}
	if out[2].SameBatch != 2 || out[6].SameBatch != 2 || out[10].SameBatch != 1 {
		t.Fatalf("unexpected SameBatch values: [%d %d %d]", out[2].SameBatch, out[6].SameBatch, out[10].SameBatch)
	}
}

View File

@@ -0,0 +1,184 @@
package lfm2
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// lfm2VisionBatchSize is fixed at 1: tiles are encoded one at a time.
const lfm2VisionBatchSize = 1

// visionPatchGrid is a processed tile's patch grid (width x height in patches).
type visionPatchGrid struct {
	Width  int
	Height int
}

// VisionSelfAttention holds one encoder layer's attention projections.
type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
// Forward computes multi-head self-attention over the patch sequence.
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	// Per-head width; assumes hiddenSize is divisible by numHeads.
	headDim := opts.hiddenSize / opts.numHeads
	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)
	// Split each projection into heads: (headDim, numHeads, tokens, batch).
	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), lfm2VisionBatchSize)
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), lfm2VisionBatchSize)
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), lfm2VisionBatchSize)
	// Scaled dot-product attention; nil mask = full bidirectional attention.
	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
	// Merge heads back into the hidden dimension for the output projection.
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), lfm2VisionBatchSize)
	return sa.Output.Forward(ctx, attention)
}
// VisionMLP is the vision encoder's two-layer feed-forward block.
type VisionMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
// Forward applies the feed-forward block: down(gelu(up(x))).
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
	up := mlp.Up.Forward(ctx, hiddenState)
	return mlp.Down.Forward(ctx, up.GELU(ctx))
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"ln2"`
MLP *VisionMLP
}
// Forward runs one pre-norm encoder layer: self-attention then the MLP,
// each applied to a layer-normed input and added back residually.
func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	attnOut := l.SelfAttention.Forward(ctx, l.LayerNorm1.Forward(ctx, hiddenState, opts.eps), opts)
	hiddenState = attnOut.Add(ctx, hiddenState)

	mlpOut := l.MLP.Forward(ctx, l.LayerNorm2.Forward(ctx, hiddenState, opts.eps))
	return mlpOut.Add(ctx, hiddenState)
}
// VisionModelOptions are the vision encoder's hyperparameters.
type VisionModelOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int
	// eps is the layer-norm epsilon used throughout the encoder.
	eps float32
}

// VisionModel is the patch encoder: conv patch embedding, learned position
// embeddings, a transformer layer stack, and a final layer norm.
type VisionModel struct {
	PatchEmbedding    *nn.Conv2D           `gguf:"patch_embd"`
	PositionEmbedding *nn.Embedding        `gguf:"position_embd"`
	PostLayerNorm     *nn.LayerNorm        `gguf:"post_ln"`
	Layers            []VisionEncoderLayer `gguf:"blk"`
	*VisionModelOptions
}
// Forward encodes pixel values into patch embeddings: conv patch embedding
// plus (possibly interpolated) position embeddings, then the transformer
// stack and a final layer norm.
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
	numPatches := patches.Width * patches.Height
	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	// Flatten the conv output to (numPatches, hidden), then transpose so the
	// hidden dimension leads for the encoder layers.
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	if m.PositionEmbedding != nil {
		posTokens := m.PositionEmbedding.Weight.Dim(1)
		// The learned table is assumed to be a square grid of side `source`.
		source := int(math.Sqrt(float64(posTokens)))
		var positionEmbeddings ml.Tensor
		if source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height) {
			// SigLIP2 NAFlex-style position interpolation for variable image sizes:
			// bilinearly resize the square source grid to the actual patch grid.
			positionIDs := ctx.Arange(0, float32(posTokens), 1, ml.DTypeI32)
			positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
			positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
			positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
			positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
				patches.Width,
				patches.Height,
				hiddenState.Dim(0),
				1,
			}, ml.SamplingModeBilinear)
			// Restore the original axis order and flatten back to one
			// embedding per patch.
			positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
			positionEmbeddings = positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
		} else {
			// Patch grid matches the trained layout: use the raw table.
			positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
			positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
		}
		hiddenState = hiddenState.Add(ctx, positionEmbeddings)
	}
	for _, layer := range m.Layers {
		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
	}
	return m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
}
// newVisionModel builds the vision encoder from config, applying the
// documented defaults for any missing keys.
func newVisionModel(c fs.Config) *VisionModel {
	opts := &VisionModelOptions{
		hiddenSize: int(c.Uint("vision.embedding_length", 1152)),
		numHeads:   int(c.Uint("vision.attention.head_count", 16)),
		imageSize:  int(c.Uint("vision.image_size", 256)),
		patchSize:  int(c.Uint("vision.patch_size", 16)),
		eps:        c.Float("vision.attention.layer_norm_epsilon", 1e-6),
	}
	return &VisionModel{
		Layers:             make([]VisionEncoderLayer, c.Uint("vision.block_count")),
		VisionModelOptions: opts,
	}
}
// VisionProjector maps vision encoder features into the language model's
// embedding space via a two-layer MLP with an optional pre-norm.
type VisionProjector struct {
	LayerNorm *nn.LayerNorm `gguf:"layer_norm"`
	Linear1   *nn.Linear    `gguf:"1"`
	Linear2   *nn.Linear    `gguf:"2"`
}

// VisionProjectorOptions configures the projector's spatial merge factor and
// whether the optional layer norm is applied.
type VisionProjectorOptions struct {
	scaleFactor  int
	useLayerNorm bool
}
// Forward projects vision features into the language embedding space, first
// spatially merging groups of scaleFactor x scaleFactor patches into single,
// wider feature vectors (pixel-unshuffle style), then applying the MLP.
func (p *VisionProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, opts VisionProjectorOptions) ml.Tensor {
	hiddenSize := visionOutputs.Dim(0)
	featureMap := visionOutputs
	merge := max(opts.scaleFactor, 1)
	if merge > 1 {
		width := patches.Width
		height := patches.Height
		// Reinterpret the flat token sequence as a (hidden, W, H) grid.
		featureMap = featureMap.Reshape(ctx, hiddenSize, width, height)
		// Match llama.cpp patch merger: pad spatial dims to merge factor.
		padWidth := (merge - width%merge) % merge
		padHeight := (merge - height%merge) % merge
		if padWidth != 0 || padHeight != 0 {
			featureMap = featureMap.Pad(ctx, 0, padWidth, padHeight, 0)
			width += padWidth
			height += padHeight
		}
		// Fold merge x merge patch groups into the channel dimension via two
		// reshape+transpose passes, then flatten back to a token sequence of
		// (hiddenSize*merge*merge) wide features.
		featureMap = featureMap.Reshape(ctx, hiddenSize*merge, width/merge, height)
		featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx, hiddenSize*merge*merge, height/merge, width/merge)
		featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx)
		featureMap = featureMap.Contiguous(ctx, featureMap.Dim(0), featureMap.Dim(1)*featureMap.Dim(2))
	}
	// Optional config-gated pre-projection layer norm.
	if opts.useLayerNorm && p.LayerNorm != nil {
		featureMap = p.LayerNorm.Forward(ctx, featureMap, 1e-5)
	}
	// Two-layer MLP with GELU in between.
	featureMap = p.Linear1.Forward(ctx, featureMap).GELU(ctx)
	return p.Linear2.Forward(ctx, featureMap)
}

View File

@@ -0,0 +1,260 @@
package lfm2
import (
"image"
stdimage "image/draw"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor holds the preprocessing configuration for LFM2 vision input.
type ImageProcessor struct {
	imageSize, patchSize, numChannels int
	// downsampleFactor mirrors the projector's scale factor so the token
	// budget math below matches the projector's output size.
	downsampleFactor    int
	imageMean, imageStd [3]float32
	// Tiling configuration for splitting large images.
	doImageSplitting bool
	minTiles         int
	maxTiles         int
	useThumbnail     bool
	tileSize         int
	// Token budget bounds used by smartResize/shouldSplit.
	minImageTokens     int
	maxImageTokens     int
	maxPixelsTolerance float64
}

// processedVisionImage is one normalized tile (or thumbnail) ready for the
// encoder: pixel data, its size, and its position in the tile grid.
type processedVisionImage struct {
	data []float32
	size image.Point
	// row/col are 1-based grid positions; zero for single images and thumbnails.
	row       int
	col       int
	thumbnail bool
}

// processedVisionLayout records the tile grid an image was split into.
type processedVisionLayout struct {
	rows         int
	cols         int
	hasThumbnail bool
}
// newImageProcessor builds the vision preprocessor from GGUF config,
// applying defaults and clamping invalid values into usable ranges.
func newImageProcessor(c fs.Config) ImageProcessor {
	mean := c.Floats("vision.image_mean")
	std := c.Floats("vision.image_std")
	processor := ImageProcessor{
		imageSize:          int(c.Uint("vision.image_size", 256)),
		patchSize:          int(c.Uint("vision.patch_size", 16)),
		numChannels:        int(c.Uint("vision.num_channels", 3)),
		downsampleFactor:   int(c.Uint("vision.projector.scale_factor", 2)),
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   c.Bool("vision.do_image_splitting", true),
		minTiles:           int(c.Uint("vision.min_tiles", 2)),
		maxTiles:           int(c.Uint("vision.max_tiles", 10)),
		useThumbnail:       c.Bool("vision.use_thumbnail", true),
		tileSize:           int(c.Uint("vision.tile_size", 512)),
		minImageTokens:     int(c.Uint("vision.min_image_tokens", 64)),
		maxImageTokens:     int(c.Uint("vision.max_image_tokens", 256)),
		maxPixelsTolerance: float64(c.Float("vision.max_pixels_tolerance", 2.0)),
	}
	// Override the 0.5 normalization defaults only when the config supplies
	// all three channels.
	if len(mean) >= 3 {
		processor.imageMean = [3]float32{mean[0], mean[1], mean[2]}
	}
	if len(std) >= 3 {
		processor.imageStd = [3]float32{std[0], std[1], std[2]}
	}
	// Keep defaults aligned with HF unless explicitly configured.
	if processor.downsampleFactor <= 0 {
		processor.downsampleFactor = 2
	}
	if processor.patchSize <= 0 {
		processor.patchSize = 16
	}
	if processor.tileSize <= 0 {
		processor.tileSize = 512
	}
	if processor.minTiles <= 0 {
		processor.minTiles = 2
	}
	if processor.maxTiles < processor.minTiles {
		processor.maxTiles = processor.minTiles
	}
	if processor.minImageTokens <= 0 {
		processor.minImageTokens = 64
	}
	if processor.maxImageTokens < processor.minImageTokens {
		processor.maxImageTokens = processor.minImageTokens
	}
	if processor.maxPixelsTolerance <= 0 {
		processor.maxPixelsTolerance = 2.0
	}
	return processor
}
// ProcessImage resizes (and optionally tiles) img for the vision encoder.
// Images large enough to split become a grid of tileSize tiles plus an
// optional whole-image thumbnail; otherwise a single smart-resized image is
// produced. Returns normalized pixel data per tile and the grid layout.
func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionImage, processedVisionLayout, error) {
	// Composite presumably flattens transparency onto an opaque background —
	// NOTE(review): confirm against imageproc.Composite.
	img = imageproc.Composite(img)
	orig := img.Bounds().Size()
	resizedWidth, resizedHeight := p.smartResize(orig.Y, orig.X)
	layout := processedVisionLayout{rows: 1, cols: 1}
	if p.shouldSplit(orig.Y, orig.X) {
		gridWidth, gridHeight, targetWidth, targetHeight := p.gridLayout(orig.Y, orig.X)
		layout.rows = gridHeight
		layout.cols = gridWidth
		layout.hasThumbnail = p.useThumbnail && gridWidth*gridHeight != 1
		// Resize to the exact grid size, then slice into tileSize x tileSize
		// tiles in row-major order with 1-based positions.
		resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
		images := make([]processedVisionImage, 0, gridWidth*gridHeight+1)
		for row := range gridHeight {
			for col := range gridWidth {
				rect := image.Rect(
					col*p.tileSize,
					row*p.tileSize,
					(col+1)*p.tileSize,
					(row+1)*p.tileSize,
				)
				tile := cropImage(resized, rect)
				images = append(images, processedVisionImage{
					data: imageproc.Normalize(tile, p.imageMean, p.imageStd, true, true),
					size: tile.Bounds().Size(),
					row:  row + 1,
					col:  col + 1,
				})
			}
		}
		// The thumbnail is a whole-image downscale appended after the tiles.
		if layout.hasThumbnail {
			thumbnail := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
			images = append(images, processedVisionImage{
				data:      imageproc.Normalize(thumbnail, p.imageMean, p.imageStd, true, true),
				size:      thumbnail.Bounds().Size(),
				thumbnail: true,
			})
		}
		return images, layout, nil
	}
	single := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
	return []processedVisionImage{{
		data: imageproc.Normalize(single, p.imageMean, p.imageStd, true, true),
		size: single.Bounds().Size(),
	}}, layout, nil
}
// shouldSplit reports whether the image exceeds the single-image pixel
// budget and should therefore be tiled.
func (p ImageProcessor) shouldSplit(height, width int) bool {
	if !p.doImageSplitting {
		return false
	}
	if p.minTiles == 1 && p.maxTiles == 1 {
		return false
	}
	factor := p.patchSize * p.downsampleFactor
	h := max(p.patchSize, roundByFactor(height, factor))
	w := max(p.patchSize, roundByFactor(width, factor))
	// Budget: max token count converted to pixels, widened by the tolerance.
	budget := float64(p.maxImageTokens*p.patchSize*p.patchSize*p.downsampleFactor*p.downsampleFactor) * p.maxPixelsTolerance
	return float64(h*w) > budget
}
// smartResize picks resized (width, height) that are multiples of
// patchSize*downsampleFactor and whose pixel count lands between the min and
// max token budgets, approximately preserving aspect ratio.
func (p ImageProcessor) smartResize(height, width int) (int, int) {
	factor := p.patchSize * p.downsampleFactor
	pixelsPerToken := p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
	minPixels := p.minImageTokens * pixelsPerToken
	maxPixels := p.maxImageTokens * pixelsPerToken

	h := max(factor, roundByFactor(height, factor))
	w := max(factor, roundByFactor(width, factor))
	switch {
	case h*w > maxPixels:
		// Too large: shrink by beta, flooring to the alignment factor.
		beta := math.Sqrt(float64(height*width) / float64(maxPixels))
		h = max(factor, int(math.Floor(float64(height)/beta/float64(factor)))*factor)
		w = max(factor, int(math.Floor(float64(width)/beta/float64(factor)))*factor)
	case h*w < minPixels:
		// Too small: grow by beta, ceiling to the alignment factor.
		beta := math.Sqrt(float64(minPixels) / float64(height*width))
		h = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		w = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}
	return w, h
}
// gridLayout picks the candidate tile grid whose aspect ratio is closest to
// the input image's, and returns both the grid dimensions and the pixel size
// of the target canvas (grid * tileSize). On an exact aspect-ratio tie the
// larger grid wins only when the image covers more than half of that
// candidate's canvas area.
func (p ImageProcessor) gridLayout(height, width int) (gridWidth, gridHeight, targetWidth, targetHeight int) {
	imgAspect := float64(width) / float64(height)
	imgArea := float64(width * height)
	best := clipImageSize{width: 1, height: 1}
	bestDiff := math.MaxFloat64
	for _, candidate := range p.targetRatios() {
		diff := math.Abs(imgAspect - float64(candidate.width)/float64(candidate.height))
		switch {
		case diff < bestDiff:
			bestDiff = diff
			best = candidate
		case diff == bestDiff:
			canvas := float64(p.tileSize * p.tileSize * candidate.width * candidate.height)
			if imgArea > 0.5*canvas {
				best = candidate
			}
		}
	}
	return best.width, best.height, p.tileSize * best.width, p.tileSize * best.height
}
// clipImageSize is a (width, height) pair used both for tile-grid shapes
// (units of tiles) and, by extension, aspect-ratio candidates.
type clipImageSize struct {
width int
height int
}
// targetRatios enumerates every distinct (width, height) tile grid whose
// total tile count falls within [minTiles, maxTiles], sorted by ascending
// tile area. First-occurrence order is preserved before sorting, so ties
// keep the same relative order as the original enumeration.
func (p ImageProcessor) targetRatios() []clipImageSize {
	seen := make(map[clipImageSize]struct{}, p.maxTiles*p.maxTiles)
	var ratios []clipImageSize
	for n := p.minTiles; n <= p.maxTiles; n++ {
		for w := 1; w <= n; w++ {
			for h := 1; h <= n; h++ {
				if tiles := w * h; tiles < p.minTiles || tiles > p.maxTiles {
					continue
				}
				r := clipImageSize{width: w, height: h}
				if _, dup := seen[r]; dup {
					continue
				}
				seen[r] = struct{}{}
				ratios = append(ratios, r)
			}
		}
	}
	slices.SortFunc(ratios, func(a, b clipImageSize) int {
		return a.width*a.height - b.width*b.height
	})
	return ratios
}
// roundByFactor rounds number to the nearest multiple of factor, with ties
// resolved toward the even multiple (banker's rounding, via
// math.RoundToEven). A non-positive factor returns number unchanged.
func roundByFactor(number, factor int) int {
	if factor <= 0 {
		return number
	}
	quotient := float64(number) / float64(factor)
	return factor * int(math.RoundToEven(quotient))
}
func cropImage(img image.Image, rect image.Rectangle) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, rect.Dx(), rect.Dy()))
stdimage.Draw(dst, dst.Bounds(), img, rect.Min, stdimage.Src)
return dst
}

View File

@@ -0,0 +1,105 @@
package lfm2
import (
"image"
"image/color"
"testing"
)
// TestProcessImageSingleTile checks that an image small enough to stay
// under the splitting threshold is emitted as exactly one non-thumbnail
// chunk at its resized size, with a 1x1 layout and no thumbnail flag.
func TestProcessImageSingleTile(t *testing.T) {
	proc := ImageProcessor{
		patchSize:          16,
		downsampleFactor:   2,
		numChannels:        3,
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   true,
		minTiles:           2,
		maxTiles:           10,
		useThumbnail:       true,
		tileSize:           512,
		minImageTokens:     64,
		maxImageTokens:     256,
		maxPixelsTolerance: 2.0,
	}
	input := image.NewRGBA(image.Rect(0, 0, 320, 320))
	got, layout, err := proc.ProcessImage(input)
	if err != nil {
		t.Fatalf("ProcessImage returned error: %v", err)
	}
	if layout.rows != 1 || layout.cols != 1 || layout.hasThumbnail {
		t.Fatalf("layout = %+v, want rows=1 cols=1 hasThumbnail=false", layout)
	}
	if len(got) != 1 {
		t.Fatalf("len(out) = %d, want 1", len(got))
	}
	if got[0].size != (image.Point{X: 320, Y: 320}) {
		t.Fatalf("single image size = %+v, want 320x320", got[0].size)
	}
	if got[0].thumbnail {
		t.Fatalf("single image should not be marked as thumbnail")
	}
}
// TestProcessImageDynamicTiling checks that a large, wide image is split
// into a multi-tile grid of 512x512 tiles followed by a single thumbnail
// chunk whose dimensions are aligned to 32 (patchSize * downsampleFactor).
func TestProcessImageDynamicTiling(t *testing.T) {
	proc := ImageProcessor{
		patchSize:          16,
		downsampleFactor:   2,
		numChannels:        3,
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   true,
		minTiles:           2,
		maxTiles:           10,
		useThumbnail:       true,
		tileSize:           512,
		minImageTokens:     64,
		maxImageTokens:     256,
		maxPixelsTolerance: 2.0,
	}
	// Wide image that should trigger multi-tile splitting.
	input := image.NewRGBA(image.Rect(0, 0, 3000, 1000))
	fill := color.RGBA{R: 120, G: 90, B: 60, A: 255}
	for y := 0; y < 1000; y++ {
		for x := 0; x < 3000; x++ {
			input.Set(x, y, fill)
		}
	}
	got, layout, err := proc.ProcessImage(input)
	if err != nil {
		t.Fatalf("ProcessImage returned error: %v", err)
	}
	tiles := layout.rows * layout.cols
	if tiles <= 1 {
		t.Fatalf("expected multi-tile layout, got %+v", layout)
	}
	if !layout.hasThumbnail {
		t.Fatalf("expected thumbnail for multi-tile layout")
	}
	if want := tiles + 1; len(got) != want {
		t.Fatalf("len(out) = %d, want %d", len(got), want)
	}
	for i := 0; i < tiles; i++ {
		if got[i].size != (image.Point{X: 512, Y: 512}) {
			t.Fatalf("tile[%d] size = %+v, want 512x512", i, got[i].size)
		}
		if got[i].thumbnail {
			t.Fatalf("tile[%d] should not be marked as thumbnail", i)
		}
	}
	thumb := got[len(got)-1]
	if !thumb.thumbnail {
		t.Fatalf("last chunk should be thumbnail")
	}
	if thumb.size.X%32 != 0 || thumb.size.Y%32 != 0 {
		t.Fatalf("thumbnail size = %+v, want dimensions aligned to 32", thumb.size)
	}
}

View File

@@ -15,6 +15,7 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nemotronh"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo3"
_ "github.com/ollama/ollama/model/models/qwen2"

View File

@@ -0,0 +1,88 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// Attention implements simple attention without RoPE for Nemotron-H.
// Unlike Qwen3Next, Nemotron-H attention has:
//   - No RoPE (position info comes from Mamba2 layers)
//   - Standard scaled dot-product attention
//
// Head counts are not stored here; Forward derives them from the widths of
// the projection weights.
type Attention struct {
Query *nn.Linear `gguf:"attn_q"` // query projection
Key *nn.Linear `gguf:"attn_k"` // key projection
Value *nn.Linear `gguf:"attn_v"` // value projection
Output *nn.Linear `gguf:"attn_output"` // output projection back to hidden size
}
// Forward computes standard scaled dot-product attention (no RoPE) over
// hiddenStates using the hybrid KV cache. hiddenStates must be laid out as
// [hidden_dim, n_tokens] or [hidden_dim, n_tokens, 1]; any other batch
// layout returns ErrUnsupportedBatchLayout. Head counts are derived from
// the projection widths divided by opts.getHeadDim().
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
// Normalize input to a 3D layout with a single sequence in the batch.
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
batchSize := nSeqTokens
// Flatten back to 2D for the linear projections.
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
headDim := opts.getHeadDim()
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
}
// Q projection; the number of heads falls out of the projection width.
query := a.Query.Forward(ctx, hiddenStates)
if query.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
}
numHeads := query.Dim(0) / headDim
query = query.Reshape(ctx, headDim, numHeads, batchSize)
// K projection; may have fewer heads than Q (grouped-query attention).
key := a.Key.Forward(ctx, hiddenStates)
if key.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
}
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
// V projection; must carry the same head count as K.
value := a.Value.Forward(ctx, hiddenStates)
if value.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
}
if value.Dim(0)/headDim != numKVHeads {
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
}
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Standard attention computation (no RoPE). A configured attention.scale
// of 0 falls back to the conventional 1/sqrt(head_dim).
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Output projection
return a.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,55 @@
package nemotronh
import (
"errors"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the layer requirements.
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
// Compile-time interface checks: HybridCache must satisfy both the basic
// cache interface and the checkpointing extension.
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
// It embeds kvcache.Recurrent and adds SSM-flavored accessor names used by
// the Mamba2 layers.
type HybridCache struct {
*kvcache.Recurrent
}
// NewHybridCache builds the shared recurrent cache for Nemotron-H with the
// given conv-state and SSM-state dimensions, wired to this package's no-op
// RoPE Shift function.
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
	cfg := kvcache.RecurrentConfig{
		Shift:               Shift,
		ConvDim:             convDim,
		ConvChannels:        convChannels,
		RecurrentStateSize:  ssmStateSize,
		CheckpointLogPrefix: "nemotronh",
	}
	return &HybridCache{Recurrent: kvcache.NewRecurrentCache(cfg)}
}
// SSMState returns the SSM state for current batch sequences, shaped as a
// 4D tensor [dState, headDim, nHead, nSeqs] via the underlying recurrent cache.
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
}
// UpdateSSMState writes a new SSM state for current batch sequences back
// into the recurrent cache for the given layer.
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
c.UpdateRecurrentState(ctx, layer, newState)
}
// slotsTensor exposes the cache slot index tensor (unexported convenience
// wrapper over the embedded Recurrent implementation).
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.SlotsTensor()
}
// seqTokens returns the number of tokens per sequence in the current batch.
func (c *HybridCache) seqTokens() int {
return c.SeqTokens()
}
// numSeqs returns the number of sequences in the current batch.
func (c *HybridCache) numSeqs() int {
return c.NumSeqs()
}

View File

@@ -0,0 +1,197 @@
package nemotronh
import (
"log/slog"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// convKernel wraps the 1D convolution kernel tensor so it can be loaded by
// GGUF tag on the nested "weight" field.
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
// The forward pass follows llama.cpp's build_mamba2_layer:
// 1. Input projection: zxBCdt = SSMIn @ hidden
// 2. Split: z, xBC, dt
// 3. Concat with conv state, apply SSMConv, save new conv state
// 4. Apply SiLU to convolved xBC
// 5. Split: x, B, C
// 6. Add dt bias
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
// 8. D skip: y = y + x * D
// 9. Swiglu with z: y = z * silu(y)
// 10. Group RMSNorm
// 11. Output projection
type Mamba2 struct {
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"` // optional conv bias
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
// Layer is this block's index, used to address its cache slots.
Layer int
}
// Forward runs one Mamba2 SSM block over hiddenStates, reading and updating
// this layer's convolution and SSM states in the hybrid cache. hiddenStates
// must be [hidden_dim, n_tokens] or [hidden_dim, n_tokens, 1]; other batch
// layouts return ErrUnsupportedBatchLayout. The returned tensor is reshaped
// to 2D [d_model, n_tokens] to match the attention layers' output layout.
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := m.Layer
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
// Normalize input to a 3D layout with a single sequence in the batch.
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
nSeqs := 1
dConv := opts.ssmDConv
dInner := opts.ssmDInner
dState := opts.ssmDState
nHead := opts.ssmNHead
headDim := dInner / nHead
nGroup := opts.ssmNGroup
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
// Split into z, xBC, dt
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
xBCSize := dInner + 2*nGroup*dState
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
if nSeqTokens == 1 {
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
}
// dt: [n_head, n_seq_tokens, n_seqs]
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
if nSeqTokens == 1 {
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
} else {
// Multi-token prefill: the slice is not contiguous, so materialize it.
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
}
// Get conv state from cache; a miss (e.g. fresh sequence) falls back to
// zeros rather than failing the whole forward pass.
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
}
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
var xBCT ml.Tensor
if nSeqTokens == 1 {
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
} else {
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
}
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
convInput := convStates.Concat(ctx, xBCT, 0)
// Save new conv state (last d_conv-1 columns)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
// Add conv bias
if m.SSMConv1DB != nil {
xBC = xBC.Add(ctx, m.SSMConv1DB)
}
// Apply SiLU
xBC = xBC.SILU(ctx)
// Split xBC into x, B, C
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
x := xBC.Slice(ctx, 0, 0, dInner, 1)
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// B: [d_state, n_group, n_seq_tokens, n_seqs]
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// C: [d_state, n_group, n_seq_tokens, n_seqs]
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// Add dt bias
dt = dt.Add(ctx, m.SSMDtB)
// Get SSM state from cache; same zeros fallback as the conv state above.
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
if err != nil {
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
}
// SSMScan
// state: [d_state, head_dim, n_head, n_seqs]
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
yElems := headDim * nHead * nSeqTokens * nSeqs
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// NOTE(review): the byte offset into ySsm is derived from x's element
// stride — assumes ySsm and x share the same element size; confirm against
// the SSMScan kernel contract.
stateOffsetBytes := yElems * x.Stride(0)
stateElems := dState * headDim * nHead * nSeqs
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
// Update SSM state in cache
cache.UpdateSSMState(ctx, layer, newState)
// D skip connection: y = y + x * D
if m.SSMD != nil {
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
xD := x.Mul(ctx, m.SSMD)
y = y.Add(ctx, xD)
}
// Swiglu with z: y = z * silu(y)
y = z.SILU(ctx, y)
// Group RMSNorm
if m.SSMNorm != nil {
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
innerPerGroup := dInner / nGroup
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
y = m.SSMNorm.Forward(ctx, y, opts.eps)
}
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
// Output projection
out := m.SSMOut.Forward(ctx, y)
// Reshape to 2D for consistency with attention output
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}

View File

@@ -0,0 +1,417 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration shared by all layers: attention
// geometry, Mamba2 SSM dimensions, the per-layer operator map, and MoE
// routing parameters. Populated once in New from GGUF metadata.
type Options struct {
hiddenSize int // model embedding width (n_embd)
numHeads int // attention heads
numKVHeads int // KV heads for attention layers
headDim int // per-head dimension; 0 means derive from hiddenSize/numHeads
eps float32 // RMSNorm epsilon
// Mamba2 SSM config
ssmDConv int // conv kernel size
ssmDInner int // inner dimension (d_inner)
ssmDState int // state dimension
ssmNHead int // number of SSM heads (dt_rank)
ssmNGroup int // number of groups for B, C
// Per-layer configuration
isRecurrent []bool // true = Mamba2, false = attention or FFN
nFF []int // n_ff per layer (0 = attention-only)
// Attention scale; 0 falls back to 1/sqrt(headDim) at runtime
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
expertWeightsNorm bool
expertWeightsScale float32
expertWeightsNormClip float32
}
// getHeadDim returns the attention head dimension: the explicitly configured
// headDim when set, otherwise hiddenSize divided by numHeads. Returns 0 when
// neither is derivable (numHeads not positive).
func (o Options) getHeadDim() int {
	switch {
	case o.headDim > 0:
		return o.headDim
	case o.numHeads > 0:
		return o.hiddenSize / o.numHeads
	default:
		return 0
	}
}
// Operator is the interface for layer operators (Mamba2 or Attention).
// Implementations may read and update per-layer state in the hybrid cache.
type Operator interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks (dense or MoE); unlike
// Operator it is stateless and cannot fail.
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// Layer represents a single transformer layer. Exactly one of Operator or
// MLP is expected to be set (chosen in New from per-layer GGUF metadata).
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` // pre-layer RMS norm
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
MLP MLP // Dense or MoE FFN, or nil
}
// Forward applies this layer to hiddenStates: pre-norm, then the layer's
// operator (Mamba2/attention) or FFN, an optional row-gather for the final
// output positions, and the residual connection. The layer index parameter
// is currently unused by the body but kept for interface symmetry.
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	residual := hiddenStates
	// Pre-layer norm.
	out := l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	switch {
	case l.Operator != nil:
		// Mamba2 or attention block.
		var err error
		if out, err = l.Operator.Forward(ctx, out, cache, opts); err != nil {
			return nil, err
		}
	case l.MLP != nil:
		// FFN-only layer.
		out = l.MLP.Forward(ctx, out, opts)
	}
	// On the last layer, keep only the rows that produce logits.
	if outputs != nil {
		out = out.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}
	// Residual connection.
	return out.Add(ctx, residual), nil
}
// Model is the main Nemotron-H model: token embedding, a stack of hybrid
// layers (Mamba2 / attention / FFN), a final RMS norm, and the LM head
// (which falls back to the tied token embedding via the alt tag).
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
// Shift is used for KV cache position shifting.
// Nemotron-H attention does not apply RoPE, so keys do not need to be
// transformed; the key tensor is returned unchanged.
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key, nil
}
// Forward runs the full model over a batch: token embedding, every hybrid
// layer in order (pointing the cache at each layer index), the final RMS
// norm, and the output projection to vocabulary logits.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hs := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	cache := m.Cache.(*HybridCache)
	last := len(m.Layers) - 1
	for i := range m.Layers {
		cache.SetLayer(i)
		// Only the final layer gathers the positions selected for output.
		var outputs ml.Tensor
		if i == last {
			outputs = batch.Outputs
		}
		var err error
		if hs, err = m.Layers[i].Forward(ctx, i, hs, outputs, cache, m.Options); err != nil {
			return nil, err
		}
	}
	hs = m.OutputNorm.Forward(ctx, hs, m.eps)
	return m.Output.Forward(ctx, hs), nil
}
// New constructs a Nemotron-H model (dense or MoE variant) from GGUF
// metadata: it derives each layer's operator type (Mamba2, attention-only,
// or FFN) from per-layer head-count and FFN-length arrays, resolves the
// attention head geometry with several fallbacks, validates MoE expert
// counts, builds the BPE tokenizer, and attaches a hybrid recurrent cache
// sized for the SSM configuration.
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer configuration from GGUF metadata
// Use the same interface pattern as qwen3next
type perLayerConfig interface {
HeadCount() []uint64
HeadCountKV() []uint64
FFNLength() []uint64
}
var headCount []uint64
var headCountKV []uint64
var ffnLength []uint64
if plc, ok := c.(perLayerConfig); ok {
headCount = plc.HeadCount()
headCountKV = plc.HeadCountKV()
ffnLength = plc.FFNLength()
}
// Build per-layer arrays with defaults
isRecurrent := make([]bool, numLayers)
nFF := make([]int, numLayers)
for i := range numLayers {
// Get per-layer values
kvHeads := uint64(1) // Default non-zero
if i < len(headCountKV) {
kvHeads = headCountKV[i]
}
ff := uint64(0)
if i < len(ffnLength) {
ff = ffnLength[i]
}
nFF[i] = int(ff)
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
// This matches llama.cpp behavior for Nemotron-H
isRecurrent[i] = kvHeads == 0 && ff == 0
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
// Mamba2 layer
layers[i].Operator = &Mamba2{Layer: i}
} else if nFF[i] == 0 {
// Attention-only layer (n_head_kv > 0, n_ff == 0)
layers[i].Operator = &Attention{}
} else {
// FFN layer (n_ff > 0)
if isMoE {
layers[i].MLP = &MoESparse{}
} else {
layers[i].MLP = &Dense{}
}
}
}
// Get attention head configuration
numHeads := int(c.Uint("attention.head_count"))
if numHeads == 0 {
// Fall back to the first layer that has both Q and KV heads.
for i := range numLayers {
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
numHeads = int(headCount[i])
break
}
}
}
numKVHeads := int(c.Uint("attention.head_count_kv"))
if numKVHeads == 0 {
// Fall back to the first attention-only layer's KV head count,
// then to numHeads if no such layer exists.
for i := range numLayers {
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
numKVHeads = int(headCountKV[i])
break
}
}
if numKVHeads == 0 {
numKVHeads = numHeads
}
}
// Head dim precedence: explicit head_dim, then key_length, then
// embedding_length / numHeads.
headDim := int(c.Uint("attention.head_dim"))
if headDim == 0 {
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
headDim = keyLength
} else if numHeads > 0 {
headDim = int(c.Uint("embedding_length")) / numHeads
}
}
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
}
if numHeads <= 0 {
// Attention layers derive per-layer head counts from projection weights.
// Keep a non-zero default to avoid invalid option math.
numHeads = 1
}
numExperts := int(c.Uint("expert_count"))
numExpertsUsed := int(c.Uint("expert_used_count"))
if numExperts > 0 {
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_rms_epsilon"),
ssmDConv: int(c.Uint("ssm.conv_kernel")),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNHead: int(c.Uint("ssm.time_step_rank")),
ssmNGroup: int(c.Uint("ssm.group_count")),
isRecurrent: isRecurrent,
nFF: nFF,
attentionScale: float64(c.Float("attention.scale")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
expertWeightsNorm: c.Bool("expert_weights_norm", false),
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
}
// Calculate cache dimensions
convDim := max(0, opts.ssmDConv-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
ssmHeadDim := 0
if opts.ssmNHead > 0 {
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
}
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
return &m, nil
}
// Both the dense and MoE GGUF architecture names map to the same
// constructor; New switches behavior based on expert_count.
func init() {
model.Register("nemotron_h", New)
model.Register("nemotron_h_moe", New)
}
// Ensure Model implements model.Model
var _ model.Model = (*Model)(nil)
// Dense implements standard feedforward with ReLU-squared activation.
// There is no gate projection: the activation is relu(up)^2.
type Dense struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
// Forward applies the dense FFN: up projection, ReLU-squared activation
// (relu(h) * relu(h)), then the down projection.
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
	h := d.Up.Forward(ctx, x)
	h = h.RELU(ctx)
	// ReLU-squared: elementwise square of the activated value.
	h = h.Mul(ctx, h)
	return d.Down.Forward(ctx, h)
}
// MoESparse implements MoE with shared experts for Nemotron-H-MoE.
// Routing uses sigmoid gating with an optional selection bias; the optional
// latent in/out projections compress the routed path, and the optional
// shared expert runs on every token alongside the routed experts.
type MoESparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` // per-token expert logits
Up *nn.LinearBatch `gguf:"ffn_up_exps"` // batched expert up projections
Down *nn.LinearBatch `gguf:"ffn_down_exps"` // batched expert down projections
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` // selection-only bias (not used for weights)
LatentIn *nn.Linear `gguf:"ffn_latent_in"` // optional latent compression
LatentOut *nn.Linear `gguf:"ffn_latent_out"` // optional latent decompression
// Shared experts
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
// Forward routes each token to the top-k experts selected by sigmoid gating
// and returns the routing-weighted sum of expert outputs, optionally passed
// through latent in/out projections and augmented by a shared expert.
// Expert activations use ReLU-squared, matching the dense FFN.
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim := hiddenStates.Dim(0)
seqLen := hiddenStates.Dim(1)
batchSize := hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
// Flatten tokens across the batch for routing.
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
// Router logits with sigmoid gating
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
// Weights come from unbiased sigmoid probabilities.
probs := routerLogits.Sigmoid(ctx)
// Selection uses optional bias.
selectionProbs := probs
if m.Bias != nil {
selectionProbs = selectionProbs.Add(ctx, m.Bias)
}
// Select top-k experts
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
// Gather each selected expert's (unbiased) probability as its weight.
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
// Normalize routing weights so they sum to 1 per token; the clamp floor
// (6.103515625e-5 == 2^-14) guards against division by a tiny sum.
if opts.expertWeightsNorm {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
weightsSum := routingWeights.SumRows(ctx)
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
routingWeights = routingWeights.Div(ctx, weightsSum)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
// Scale routing weights
if opts.expertWeightsScale != 1.0 {
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
}
// Optional latent compression on the routed path only.
routedInput := hiddenStates2D
if m.LatentIn != nil {
routedInput = m.LatentIn.Forward(ctx, routedInput)
}
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
// Expert computation with ReLU-squared activation
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
upOut = upOut.RELU(ctx)
upOut = upOut.Mul(ctx, upOut) // Square
experts := m.Down.Forward(ctx, upOut, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts: accumulate strided views of the per-expert outputs
// instead of materializing a transpose.
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
if m.LatentOut != nil {
moeOut = m.LatentOut.Forward(ctx, moeOut)
}
// Add shared experts if present
if m.SharedUp != nil {
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
sharedUp = sharedUp.RELU(ctx)
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}

View File

@@ -32,6 +32,8 @@ type LFM2Parser struct {
hasThinkingSupport bool
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
toolNames map[string]struct{}
hasTools bool
}
func (p *LFM2Parser) HasToolSupport() bool {
@@ -63,6 +65,13 @@ func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.T
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.toolNames = make(map[string]struct{}, len(tools))
p.hasTools = len(tools) > 0
for _, tool := range tools {
if tool.Function.Name != "" {
p.toolNames[tool.Function.Name] = struct{}{}
}
}
p.setInitialState(lastMessage, thinkValue)
return tools
}
@@ -105,9 +114,33 @@ func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string,
}
}
// Fallback for models that emit bare tool calls without <|tool_call_*|> wrappers.
if done && len(toolCalls) == 0 && p.hasTools {
candidate := strings.TrimSpace(contentSb.String())
if fallbackCalls, parseErr := p.parseToolCallsContent(candidate); parseErr == nil && p.toolCallsAllowed(fallbackCalls) {
contentSb.Reset()
toolCalls = append(toolCalls, fallbackCalls...)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
// toolCallsAllowed reports whether the parsed calls may be emitted as tool
// calls: there must be at least one call, and when named tools were declared
// every call must reference a known tool. An empty toolNames set permits
// any call.
func (p *LFM2Parser) toolCallsAllowed(calls []api.ToolCall) bool {
	switch {
	case len(calls) == 0:
		return false
	case len(p.toolNames) == 0:
		return true
	}
	for _, call := range calls {
		if _, known := p.toolNames[call.Function.Name]; !known {
			return false
		}
	}
	return true
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
@@ -269,36 +302,16 @@ func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
return events, false
}
// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
// parseToolCallsContent parses one or more Python-style tool calls.
// Example: [func1(arg='v'), func2(x=1)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
// Be tolerant of malformed outputs that include wrapper tags without proper pairing.
content = strings.TrimSpace(strings.TrimPrefix(content, lfm2ToolCallStartTag))
content = strings.TrimSpace(strings.TrimSuffix(content, lfm2ToolCallEndTag))
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return nil, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return []api.ToolCall{{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}}, nil
}
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
// Parse Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCalls(content)
}
@@ -417,21 +430,16 @@ func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error)
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
// Skip separators and whitespace.
for i < len(argsStr) && (argsStr[i] == ',' || unicode.IsSpace(rune(argsStr[i]))) {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
@@ -439,60 +447,238 @@ func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
key := strings.TrimSpace(argsStr[keyStart:i])
if key == "" {
return errors.New("invalid argument: empty key")
}
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
for i < len(argsStr) && unicode.IsSpace(rune(argsStr[i])) {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
if i >= len(argsStr) {
return errors.New("invalid argument: missing value")
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
value, next, err := parsePythonArgValue(argsStr, i)
if err != nil {
return err
}
args.Set(key, value)
i = next
// Optional trailing comma before next key/value.
if i < len(argsStr) && argsStr[i] == ',' {
i++
}
}
return nil
}
// parsePythonArgValue parses one argument value from s starting at index i
// and returns the parsed value, the index just past it, and any error.
// Quoted strings return their raw contents (escape sequences are kept as-is);
// unquoted values run until a comma at bracket depth zero and are interpreted
// by parsePythonLiteral, with nested (), [], {} and string literals skipped.
func parsePythonArgValue(s string, i int) (any, int, error) {
	if i >= len(s) {
		return nil, i, errors.New("invalid argument: missing value")
	}

	if q := s[i]; q == '\'' || q == '"' {
		// String literal: scan to the matching quote, honoring backslash escapes.
		i++
		begin := i
		for i < len(s) {
			switch {
			case s[i] == '\\' && i+1 < len(s):
				i += 2 // keep escape sequence verbatim
			case s[i] == q:
				return s[begin:i], i + 1, nil
			default:
				i++
			}
		}
		return nil, i, errors.New("invalid argument: unterminated string")
	}

	// Unquoted literal: consume until a top-level comma.
	begin := i
	var parens, squares, curlies int
	var strQuote byte
	inStr, esc := false, false
	for ; i < len(s); i++ {
		c := s[i]
		if inStr {
			switch {
			case esc:
				esc = false
			case c == '\\':
				esc = true
			case c == strQuote:
				inStr = false
			}
			continue
		}
		switch c {
		case '\'', '"':
			inStr, strQuote = true, c
		case '(':
			parens++
		case ')':
			if parens > 0 {
				parens--
			}
		case '[':
			squares++
		case ']':
			if squares > 0 {
				squares--
			}
		case '{':
			curlies++
		case '}':
			if curlies > 0 {
				curlies--
			}
		case ',':
			if parens == 0 && squares == 0 && curlies == 0 {
				v, err := parsePythonLiteral(strings.TrimSpace(s[begin:i]))
				return v, i, err
			}
		}
	}
	v, err := parsePythonLiteral(strings.TrimSpace(s[begin:i]))
	return v, i, err
}
// parsePythonLiteral interprets an unquoted Python literal token as a Go
// value: booleans (true/True, false/False), null/None, int64, float64, or a
// JSON-decoded list/dict (falling back to a Python-to-JSON rewrite when the
// token isn't valid JSON as-is). Anything unrecognized is returned verbatim.
func parsePythonLiteral(token string) (any, error) {
	switch token {
	case "":
		return "", nil
	case "true", "True":
		return true, nil
	case "false", "False":
		return false, nil
	case "null", "None":
		return nil, nil
	}
	if n, err := strconv.ParseInt(token, 10, 64); err == nil {
		return n, nil
	}
	if f, err := strconv.ParseFloat(token, 64); err == nil {
		return f, nil
	}
	if strings.HasPrefix(token, "[") || strings.HasPrefix(token, "{") {
		var decoded any
		if json.Unmarshal([]byte(token), &decoded) == nil {
			return decoded, nil
		}
		// Not valid JSON; try rewriting Python syntax (quotes, True/False/None).
		if jsonText, err := pythonLiteralToJSON(token); err == nil {
			if json.Unmarshal([]byte(jsonText), &decoded) == nil {
				return decoded, nil
			}
		}
	}
	return token, nil
}
// pythonLiteralToJSON rewrites a Python-style literal (single-quoted strings,
// True/False/None identifiers) into JSON so it can be decoded with
// encoding/json. String contents are otherwise preserved byte-for-byte.
// It returns an error when a string literal is left unterminated.
func pythonLiteralToJSON(s string) (string, error) {
	var out strings.Builder
	out.Grow(len(s) + len(s)/8)
	inString := false
	var quote byte
	escaped := false
	for i := 0; i < len(s); i++ {
		ch := s[i]
		if inString {
			if escaped {
				out.WriteByte(ch)
				escaped = false
				continue
			}
			if ch == '\\' {
				// Python permits \' but JSON does not; emit the quote alone so
				// the output stays valid JSON (and the quote cannot terminate
				// the string early).
				if i+1 < len(s) && s[i+1] == '\'' {
					out.WriteByte('\'')
					i++
					continue
				}
				out.WriteByte(ch)
				escaped = true
				continue
			}
			if ch == quote {
				out.WriteByte('"')
				inString = false
				continue
			}
			if quote == '\'' && ch == '"' {
				// A bare double quote inside a single-quoted Python string must
				// be escaped once the string is re-quoted with double quotes.
				out.WriteString(`\"`)
				continue
			}
			out.WriteByte(ch)
			continue
		}
		if ch == '\'' || ch == '"' {
			inString = true
			quote = ch
			escaped = false
			out.WriteByte('"')
			continue
		}
		// Replace Python identifiers with JSON equivalents when outside strings.
		if isIdentStart(ch) {
			j := i + 1
			for j < len(s) && isIdentPart(s[j]) {
				j++
			}
			switch ident := s[i:j]; ident {
			case "True":
				out.WriteString("true")
			case "False":
				out.WriteString("false")
			case "None":
				out.WriteString("null")
			default:
				out.WriteString(ident)
			}
			i = j - 1
			continue
		}
		out.WriteByte(ch)
	}
	if inString {
		return "", errors.New("unterminated string")
	}
	return out.String(), nil
}

// isIdentStart reports whether b can begin an ASCII identifier.
func isIdentStart(b byte) bool {
	return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || b == '_'
}

// isIdentPart reports whether b can appear after the first byte of an ASCII
// identifier.
func isIdentPart(b byte) bool {
	return isIdentStart(b) || (b >= '0' && b <= '9')
}

View File

@@ -39,7 +39,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "tool_call_simple",
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "I'll check the weather.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
@@ -55,7 +55,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
input: "Getting weather for both cities.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|tool_call_start|>[get_weather(location=\"London\")]<|tool_call_end|>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
@@ -79,7 +79,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "complex_tool_arguments",
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
input: "Processing data.<|tool_call_start|>[process_data(items=['item1','item2'], config={'enabled': True, 'threshold': 0.95})]<|tool_call_end|>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
@@ -96,7 +96,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "thinking_with_tool_call",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
@@ -144,16 +144,16 @@ func TestLFM2Parser(t *testing.T) {
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
name: "tool_call_with_text_args",
input: "Searching for information.<|tool_call_start|>[search(query='beijing weather', language='zh')]<|tool_call_end|>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "北京天气",
"language": "中文",
"query": "beijing weather",
"language": "zh",
}),
},
},
@@ -169,7 +169,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "empty_tool_call_args",
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
input: "Pinging server.<|tool_call_start|>[ping()]<|tool_call_end|>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
@@ -353,7 +353,7 @@ func TestLFM2Parser_Streaming(t *testing.T) {
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "[get_weather(", "location=\"Paris\")]", "<|tool_call_end|>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
@@ -381,16 +381,16 @@ func TestLFM2Parser_Streaming(t *testing.T) {
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
name: "streaming_tool_call_with_split_python",
chunks: []string{"Processing.", "<|tool_call_start|>", "[calc(", "x=42, ", "y=24)]", "<|tool_call_end|>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
"x": float64(42),
"y": float64(24),
"x": int64(42),
"y": int64(24),
}),
},
},
@@ -516,8 +516,8 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
expectError bool
}{
{
name: "valid_tool_call",
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
name: "python_style_single_call",
content: `get_weather(location="Paris")`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
@@ -528,21 +528,33 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "complex_arguments",
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
name: "python_style_with_brackets",
content: `[get_weather(location="Paris")]`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Name: "get_weather",
Arguments: testArgs(map[string]any{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
"location": "Paris",
}),
},
},
},
{
name: "empty_arguments",
content: `{"name":"ping","arguments":{}}`,
name: "python_style_complex_arguments",
content: `process(items=['a', 'b'], config={'enabled': True})`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
"items": []any{"a", "b"},
"config": map[string]any{"enabled": true},
}),
},
},
},
{
name: "python_style_empty_arguments",
content: `ping()`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
@@ -551,44 +563,13 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "unicode_in_tool_name",
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
"城市": "北京",
}),
},
},
},
{
name: "numeric_arguments",
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
{
name: "invalid_json",
content: `{invalid json}`,
name: "missing_parenthesis",
content: `get_weather location="Paris")`,
expectError: true,
},
{
name: "missing_name",
content: `{"arguments":{"arg":"value"}}`,
expectError: true,
},
{
name: "empty_name",
content: `{"name":"","arguments":{"arg":"value"}}`,
name: "invalid_argument_format",
content: `bash(command)`,
expectError: true,
},
}
@@ -645,6 +626,24 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
},
},
},
{
name: "python_style_complex_literals",
content: `[AskUserQuestion(question="What's up?", headers=['Hello!', 'How can I help you?'], options=['Debugging help', 'Code writing assistance'], multiSelect=False, metadata={'priority': 1, 'active': True})]`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "AskUserQuestion",
Arguments: testArgs(map[string]any{
"question": "What's up?",
"headers": []any{"Hello!", "How can I help you?"},
"options": []any{"Debugging help", "Code writing assistance"},
"multiSelect": false,
"metadata": map[string]any{"priority": float64(1), "active": true},
}),
},
},
},
},
{
name: "single_python_style_call",
content: `bash(command='ls -la')`,
@@ -673,6 +672,34 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
},
},
},
{
name: "single_call_with_orphan_end_tag",
content: `[bash(command='ls')]<|tool_call_end|>`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "ls",
}),
},
},
},
},
{
name: "single_call_with_wrapper_tags",
content: `<|tool_call_start|>[bash(command='pwd')]<|tool_call_end|>`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "pwd",
}),
},
},
},
},
{
name: "multiple_different_functions",
content: `[get_weather(location='Paris'),search(query='news')]`,
@@ -1086,3 +1113,106 @@ func TestLFM2Parser_EdgeCases(t *testing.T) {
})
}
}
// TestLFM2Parser_BareToolCallFallback ensures that a bare Python-style call
// naming a registered tool is parsed as a tool call even without
// <|tool_call_start|> wrapper tags.
func TestLFM2Parser_BareToolCallFallback(t *testing.T) {
	p := &LFM2Parser{}
	known := []api.Tool{{
		Type:     "function",
		Function: api.ToolFunction{Name: "get_weather"},
	}}
	p.Init(known, nil, &api.ThinkValue{Value: false})

	content, thinking, calls, err := p.Add(`[get_weather(location="Paris")]`, true)
	if err != nil {
		t.Fatalf("Add() error = %v", err)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name get_weather, got %q", calls[0].Function.Name)
	}
}
// TestLFM2Parser_BareUnknownToolCallDoesNotParse ensures that a bare call
// naming a tool that was never registered is left in the content untouched.
func TestLFM2Parser_BareUnknownToolCallDoesNotParse(t *testing.T) {
	p := &LFM2Parser{}
	known := []api.Tool{{
		Type:     "function",
		Function: api.ToolFunction{Name: "get_weather"},
	}}
	p.Init(known, nil, &api.ThinkValue{Value: false})

	const input = `[unknown_tool(location="Paris")]`
	content, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatalf("Add() error = %v", err)
	}
	if content != input {
		t.Fatalf("expected content to be preserved, got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}
// TestLFM2Parser_ImagePlaceholdersPreserved checks that image placeholder
// markers in the input pass through the parser untouched and produce no
// spurious tool calls or thinking output.
func TestLFM2Parser_ImagePlaceholdersPreserved(t *testing.T) {
	cases := []struct {
		name  string
		input string
	}{
		{name: "indexed_img_placeholder", input: "[img-0]describe this image"},
		{name: "template_image_placeholder", input: "<image>describe this image"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			p := &LFM2Parser{}
			known := []api.Tool{{
				Type:     "function",
				Function: api.ToolFunction{Name: "bash"},
			}}
			p.Init(known, nil, &api.ThinkValue{Value: false})

			content, thinking, calls, err := p.Add(tc.input, true)
			if err != nil {
				t.Fatalf("Add() error = %v", err)
			}
			if content != tc.input {
				t.Fatalf("expected content %q, got %q", tc.input, content)
			}
			if thinking != "" {
				t.Fatalf("expected empty thinking, got %q", thinking)
			}
			if len(calls) != 0 {
				t.Fatalf("expected no tool calls, got %d", len(calls))
			}
		})
	}
}

View File

@@ -57,6 +57,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"lfm2"},
{"lfm2-thinking"},
{"harmony"},
}

View File

@@ -1,7 +1,9 @@
package renderers
import (
"bytes"
"encoding/json"
"sort"
"strings"
"github.com/ollama/ollama/api"
@@ -9,18 +11,218 @@ import (
// LFM2Renderer renders chat messages into the LFM2 prompt format:
// ChatML-style <|im_start|>role ... <|im_end|> blocks plus Liquid-style
// special tags for the tool list, tool calls, and tool responses.
type LFM2Renderer struct {
	IsThinking bool // thinking variant: past <think>...</think> blocks may be kept
	useImgTags bool // render image placeholders as "[img]" instead of "<image>"
}

// lfm2BOSToken is written at the start of every rendered prompt.
const lfm2BOSToken = "<|startoftext|>"

// Special tags delimiting tool metadata in LFM2 prompts.
const (
	lfm2ToolListStartTag     = "<|tool_list_start|>"
	lfm2ToolListEndTag       = "<|tool_list_end|>"
	lfm2ToolCallStartTag     = "<|tool_call_start|>"
	lfm2ToolCallEndTag       = "<|tool_call_end|>"
	lfm2ToolResponseStartTag = "<|tool_response_start|>"
	lfm2ToolResponseEndTag   = "<|tool_response_end|>"
)
// lfm2RenderSystemContent flattens a system message's content into plain
// text. Strings pass through unchanged; structured ([]any) content
// contributes only the "text" field of items whose "type" is "text".
// Any other content type yields "".
func lfm2RenderSystemContent(content any) string {
	if s, ok := content.(string); ok {
		return s
	}
	parts, ok := content.([]any)
	if !ok {
		return ""
	}
	var out strings.Builder
	for _, part := range parts {
		m, isMap := part.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind != "text" {
			continue
		}
		if text, isStr := m["text"].(string); isStr {
			out.WriteString(text)
		}
	}
	return out.String()
}
// lfm2JSON marshals v without HTML escaping and then reformats separators to
// match HF `tojson` (json.dumps with default separators): a space follows
// every comma and colon that sits outside a string literal and is not already
// followed by whitespace.
func lfm2JSON(v any) string {
	var raw bytes.Buffer
	enc := json.NewEncoder(&raw)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(v); err != nil {
		// Encoder failed; fall back to plain Marshal output.
		fallback, _ := json.Marshal(v)
		return string(fallback)
	}
	data := bytes.TrimSuffix(raw.Bytes(), []byte{'\n'})

	var b strings.Builder
	b.Grow(len(data) + len(data)/8)
	inStr, esc := false, false
	for idx := 0; idx < len(data); idx++ {
		c := data[idx]
		b.WriteByte(c)
		if inStr {
			switch {
			case esc:
				esc = false
			case c == '\\':
				esc = true
			case c == '"':
				inStr = false
			}
			continue
		}
		if c == '"' {
			inStr = true
			continue
		}
		if (c == ':' || c == ',') && idx+1 < len(data) {
			switch data[idx+1] {
			case ' ', '\n', '\r', '\t':
				// Already spaced; leave as-is.
			default:
				b.WriteByte(' ')
			}
		}
	}
	return b.String()
}
// lfm2ImagePlaceholder returns the inline marker used for image content:
// "[img]" when legacy img tags are requested, otherwise "<image>".
func lfm2ImagePlaceholder(useImgTags bool) string {
	placeholder := "<image>"
	if useImgTags {
		placeholder = "[img]"
	}
	return placeholder
}
// lfm2RenderContent flattens message content into prompt text. Strings pass
// through unchanged. Structured ([]any) content renders "text" items as their
// text, "image" items as the image placeholder, and anything else (including
// non-map items) as JSON. Other content types are rendered as JSON outright.
func lfm2RenderContent(content any, useImgTags bool) string {
	switch parts := content.(type) {
	case string:
		return parts
	case []any:
		var out strings.Builder
		for _, part := range parts {
			obj, isMap := part.(map[string]any)
			if !isMap {
				out.WriteString(lfm2JSON(part))
				continue
			}
			kind, _ := obj["type"].(string)
			switch kind {
			case "image":
				out.WriteString(lfm2ImagePlaceholder(useImgTags))
			case "text":
				if text, isStr := obj["text"].(string); isStr {
					out.WriteString(text)
				} else {
					out.WriteString(lfm2JSON(part))
				}
			default:
				out.WriteString(lfm2JSON(part))
			}
		}
		return out.String()
	default:
		return lfm2JSON(content)
	}
}
// lfm2ToolSchema picks the JSON shape used for a tool definition in the
// system prompt: named tools are reduced to their function schema
// (name/description/parameters), unnamed ones are passed through whole.
func lfm2ToolSchema(tool api.Tool) any {
	if tool.Function.Name != "" {
		// LFM2 templates are typically fed function-schema objects.
		return tool.Function
	}
	return tool
}
// lfm2ToolCallArgument serializes one tool-call argument value; it reuses
// lfm2JSON so argument formatting matches the rest of the prompt.
func lfm2ToolCallArgument(v any) string {
	return lfm2JSON(v)
}
// lfm2RenderToolCalls serializes assistant tool calls in LFM2's Python-call
// style: <|tool_call_start|>[name(key=json,...),...]<|tool_call_end|>.
// Argument keys are sorted so output is deterministic.
func lfm2RenderToolCalls(calls []api.ToolCall) string {
	var out strings.Builder
	out.WriteString(lfm2ToolCallStartTag)
	out.WriteString("[")
	for idx, call := range calls {
		if idx > 0 {
			out.WriteString(",")
		}
		out.WriteString(call.Function.Name)
		out.WriteString("(")
		names := make([]string, 0, call.Function.Arguments.Len())
		for name := range call.Function.Arguments.All() {
			names = append(names, name)
		}
		sort.Strings(names)
		for argIdx, name := range names {
			if argIdx > 0 {
				out.WriteString(",")
			}
			argValue, _ := call.Function.Arguments.Get(name)
			out.WriteString(name)
			out.WriteString("=")
			out.WriteString(lfm2ToolCallArgument(argValue))
		}
		out.WriteString(")")
	}
	out.WriteString("]")
	out.WriteString(lfm2ToolCallEndTag)
	return out.String()
}
// renderMessageContent renders a message body, prefixing one image
// placeholder per attached image unless the content already carries
// placeholders (chatPrompt may have inserted [img] / [img-n] / <image>
// markers before this runs).
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
	content := lfm2RenderContent(message.Content, r.useImgTags)
	if len(message.Images) == 0 {
		return content
	}
	alreadyTagged := strings.Contains(content, "[img]") ||
		strings.Contains(content, "[img-") ||
		strings.Contains(content, "<image>")
	if alreadyTagged {
		return content
	}
	placeholder := lfm2ImagePlaceholder(r.useImgTags)
	return strings.Repeat(placeholder, len(message.Images)) + content
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Follow Liquid tool-use formatting for LFM2 tool wrappers.
sb.WriteString(lfm2BOSToken)
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
firstSystemContent = lfm2RenderSystemContent(messages[0].Content)
startIdx = 1
}
@@ -29,18 +231,17 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
firstSystemContent += "List of tools: "
firstSystemContent += lfm2ToolListStartTag
firstSystemContent += "["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
firstSystemContent += lfm2JSON(lfm2ToolSchema(tool))
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
firstSystemContent += lfm2ToolListEndTag
}
// Output first system block if it has content
@@ -50,6 +251,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
sb.WriteString("<|im_end|>\n")
}
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
@@ -59,85 +262,47 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
sb.WriteString("<|im_start|>")
sb.WriteString(message.Role)
sb.WriteString("\n")
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
content := r.renderMessageContent(message)
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
content = strings.TrimSpace(content[idx+len("</think>"):])
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
}
if message.Role == "assistant" && len(message.ToolCalls) > 0 && !strings.Contains(content, lfm2ToolCallStartTag) {
if strings.TrimSpace(content) == "" {
content = lfm2RenderToolCalls(message.ToolCalls) + content
} else {
sb.WriteString(content)
content = lfm2RenderToolCalls(message.ToolCalls) + "\n" + content
}
}
if message.Role == "tool" && !strings.Contains(content, lfm2ToolResponseStartTag) {
content = lfm2ToolResponseStartTag + content + lfm2ToolResponseEndTag
}
sb.WriteString(content)
if !prefill {
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
needsGenerationPrompt := true
if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" {
needsGenerationPrompt = false
}
if needsGenerationPrompt {
// RenderWithRenderer uses add_generation_prompt=true for chat rendering,
// unless we're prefilling a trailing assistant message.
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil

View File

@@ -8,73 +8,136 @@ import (
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
func TestLFM2Renderer_ChatTemplateParity(t *testing.T) {
tests := []struct {
name string
renderer *LFM2Renderer
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
name: "user_only",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
name: "system_and_user",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
name: "tools_without_system",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "user", Content: "Use tools"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nList of tools: <|tool_list_start|>[{\"name\": \"get_weather\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nUse tools<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
name: "first_system_combined_with_tools",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "Follow instructions."},
{Role: "user", Content: "Do work"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_a",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_b",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>system\nFollow instructions.\nList of tools: <|tool_list_start|>[{\"name\": \"tool_a\", \"parameters\": {\"type\": \"object\", \"properties\": null}}, {\"name\": \"tool_b\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant_tool_calls_and_tool_responses_are_rendered",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Call a tool"},
{
Role: "assistant",
Content: "",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|im_end|>\n<|im_start|>tool\n<|tool_response_start|>22C<|tool_response_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant_tool_calls_with_content_preserves_both",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Call a tool"},
{
Role: "assistant",
Content: "Checking now.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>\nChecking now.",
},
{
name: "thinking_strips_non_last_assistant_when_disabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -82,11 +145,11 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
name: "thinking_preserves_past_assistant_when_enabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -94,334 +157,137 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
name: "assistant with tool calls",
name: "arbitrary_roles_are_rendered_verbatim",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "developer", Content: "Do X"},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
expected: "<|startoftext|><|im_start|>developer\nDo X<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with content and tool calls",
name: "empty_messages_still_add_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: nil,
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>assistant\n",
},
{
name: "assistant_prefill_no_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
rendered, err := tt.renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestLFM2Renderer_Images verifies image placeholder insertion for the LFM2
// renderer: the default <image> token, the [img] variant, and that content
// which already contains a placeholder is left untouched.
func TestLFM2Renderer_Images(t *testing.T) {
	cases := []struct {
		name     string
		renderer *LFM2Renderer
		message  api.Message
		expected string
	}{
		{
			name:     "single_image_default_placeholder",
			renderer: &LFM2Renderer{},
			message: api.Message{
				Role:    "user",
				Content: "Describe this image.",
				Images:  []api.ImageData{api.ImageData("img1")},
			},
			expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name:     "multiple_images_default_placeholder",
			renderer: &LFM2Renderer{},
			message: api.Message{
				Role:    "user",
				Content: "Describe these images.",
				Images:  []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
			},
			expected: "<|startoftext|><|im_start|>user\n<image><image>Describe these images.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name:     "single_image_img_tag_placeholder",
			renderer: &LFM2Renderer{useImgTags: true},
			message: api.Message{
				Role:    "user",
				Content: "Describe this image.",
				Images:  []api.ImageData{api.ImageData("img1")},
			},
			expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name:     "existing_indexed_img_placeholder_not_duplicated",
			renderer: &LFM2Renderer{useImgTags: true},
			message: api.Message{
				Role:    "user",
				Content: "[img-0]Describe this image.",
				Images:  []api.ImageData{api.ImageData("img1")},
			},
			expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name:     "existing_template_image_placeholder_not_duplicated",
			renderer: &LFM2Renderer{},
			message: api.Message{
				Role:    "user",
				Content: "<image>Describe this image.",
				Images:  []api.ImageData{api.ImageData("img1")},
			},
			expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rendered, err := tc.renderer.Render([]api.Message{tc.message}, nil, nil)
			if err != nil {
				t.Fatalf("Render() error = %v", err)
			}
			if diff := cmp.Diff(tc.expected, rendered); diff != "" {
				t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestLFM2Renderer_JSONFormatting checks that lfm2JSON emits spaced
// separators and does not HTML-escape characters such as '<' and '>'.
func TestLFM2Renderer_JSONFormatting(t *testing.T) {
	echoTool := api.Tool{
		Type: "function",
		Function: api.ToolFunction{
			Name:        "echo",
			Description: "<html>",
			Parameters: api.ToolFunctionParameters{
				Type: "object",
			},
		},
	}

	want := `{"type": "function", "function": {"name": "echo", "description": "<html>", "parameters": {"type": "object", "properties": null}}}`
	if diff := cmp.Diff(want, lfm2JSON(echoTool)); diff != "" {
		t.Fatalf("lfm2JSON mismatch (-want +got):\n%s", diff)
	}
}

View File

@@ -85,9 +85,9 @@ func rendererForName(name string) Renderer {
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
default:
return nil
}

View File

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

View File

@@ -37,9 +37,11 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
case "MXFP8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
case "FP8", "Q8", "INT8":
// 8-bit quantization with affine mode (default for quantized models)
return 64, 8, "affine"
case "":
return 0, 0, ""
default:
return 32, 8, "affine" // Default to affine
}

View File

@@ -3,94 +3,65 @@
package mlxrunner
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
Tokens []int32
Caches []cache.Cache
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
}
}
// LogCache emits a trace-level summary of the cached sequence: the number of
// tokens covered by the KV cache (taken from the first cache's offset) and
// the combined byte size of all key/value state arrays.
func (c *CacheEntry) LogCache() {
	// Guard the empty case: indexing c.Caches[0] below would otherwise panic.
	if len(c.Caches) == 0 {
		return
	}
	var totalBytes int
	for _, kv := range c.Caches {
		k, v := kv.State()
		totalBytes += k.NumBytes() + v.NumBytes()
	}
	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -3,8 +3,7 @@
package cache
import (
"log/slog"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -13,6 +12,7 @@ type Cache interface {
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Free()
Offset() int
Len() int
}
@@ -47,6 +47,7 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
}
@@ -73,12 +74,19 @@ func (c *KVCache) Trim(n int) int {
}
func (c *KVCache) Clone() Cache {
return &KVCache{
clone := &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
mlx.Pin(clone.keys, clone.values)
return clone
}
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
}
func (c *KVCache) Offset() int { return c.offset }
@@ -104,9 +112,10 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys, values
c.keys, c.values = keys.Clone(), values.Clone()
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
@@ -130,7 +139,7 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
}
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
@@ -145,6 +154,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
c.idx = prev
}

View File

@@ -119,16 +119,13 @@ func NewClient(modelName string) (*Client, error) {
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("mlx-runner", "msg", scanner.Text())
}
io.Copy(os.Stderr, stdout) //nolint:errcheck
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("mlx-runner", "msg", line)
fmt.Fprintln(os.Stderr, line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()

View File

@@ -7,48 +7,29 @@ import "C"
import (
"encoding/binary"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
ctx C.mlx_array
name string
pinned bool
}
var arrays []*Array
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
func New(name string) *Array {
t := &Array{name: name}
arrays = append(arrays, t)
return t
}
@@ -133,18 +114,51 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// lifecycle utilities
// Pin marks arrays as in-use so they are retained during Sweep.
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = true
}
}
}
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = false
}
}
}
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
n := 0
for _, t := range arrays {
if t.pinned && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
arrays = arrays[:n]
}
// misc. utilities
func (t *Array) Valid() bool {
@@ -159,7 +173,10 @@ func (t *Array) String() string {
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Bool("pinned", t.pinned),
}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
@@ -238,37 +255,15 @@ func (t Array) Save(name string) error {
return nil
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
for _, t := range arrays {
nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}

View File

@@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
@@ -31,7 +31,7 @@ type LayerNorm struct {
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -41,7 +41,7 @@ type RMSNorm struct {
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -55,7 +55,7 @@ type RoPE struct {
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
t.ctx,

View File

@@ -37,7 +37,9 @@ func Load(path string) iter.Seq2[string, *Array] {
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
arr := New(name)
arr.ctx = value
if !yield(name, arr) {
break
}
}

View File

@@ -10,43 +10,43 @@ import (
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
out := New("ABS")
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
out := New("ADD")
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
out := New("ADDMM")
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
out := New("ARGMAX")
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
out := New("ARGPARTITION")
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS", t)
out := New("ARGSORT_AXIS")
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
out := New("AS_TYPE")
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
@@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
out := New("AS_STRIDED")
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
@@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", s...)
out := New("CONCATENATE")
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE", t, other)
out := New("DIVIDE")
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN", t)
out := New("FLATTEN")
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE", t, other)
out := New("FLOOR_DIVIDE")
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
@@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM", t, other, lhs, rhs)
out := New("GATHER_MM")
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
out := New("MULTIPLY")
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
out := New("NEGATIVE")
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
out := New("POWER")
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
out := New("PUT_ALONG_AXIS")
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
out := New("RESHAPE")
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
out := New("SIGMOID")
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
out := New("SQRT")
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
out := New("SQUEEZE")
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS", append(others, t)...)
out := New("STACK_AXIS")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
out := New("SUBTRACT")
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS", t)
out := New("SUM_AXIS")
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
out := New("TAKE_AXIS")
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", t, indices)
out := New("TAKE_ALONG_AXIS")
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
out := New("TANH")
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
@@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
out := New("TRANSPOSE")
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}

View File

@@ -41,14 +41,12 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
inputs := []*Array{w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("DEQUANTIZE", inputs...)
out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
return out
}
@@ -59,14 +57,12 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("QUANTIZED_MATMUL", inputs...)
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
@@ -77,22 +73,18 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
inputs = append(inputs, lhsIndices)
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
inputs = append(inputs, rhsIndices)
}
out := New("GATHER_QMM", inputs...)
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
@@ -104,7 +96,7 @@ func Tile(a *Array, reps []int32) *Array {
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE", a)
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
@@ -116,7 +108,7 @@ func Tri(n, m int32, k int) *Array {
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@@ -131,7 +123,7 @@ func Stack(arrays []*Array, axis int) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK", arrays...)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
@@ -153,13 +145,13 @@ func Take(a *Array, indices *Array, axis int) *Array {
}
func RSqrt(a *Array) *Array {
out := New("RSQRT", a)
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS", a)
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
@@ -235,7 +227,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array {
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE", a)
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
@@ -257,7 +249,7 @@ func SiLU(a *Array) *Array {
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", x, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
@@ -289,13 +281,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", q, k, v, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -322,7 +314,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array {
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR", a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -330,7 +322,7 @@ func AddScalar(a *Array, s float32) *Array {
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR", a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -338,7 +330,7 @@ func MulScalar(a *Array, s float32) *Array {
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR", a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out

View File

@@ -7,7 +7,7 @@ import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

View File

@@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
out := New("SLICE")
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
@@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array {
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
out := New("SLICE_UPDATE")
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),

View File

@@ -78,8 +78,21 @@ func New(root *model.Root) (Model, error) {
return fn(root)
}
// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {
return m.LoadWeights
return func(tensors map[string]*mlx.Array) error {
if err := m.LoadWeights(tensors); err != nil {
return err
}
collected := mlx.Collect(m)
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
return nil
}
}

View File

@@ -17,8 +17,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
case "FP8", "Q8", "INT8":
return 64, 8, "affine"
case "":
return 0, 0, ""
default:
return 32, 8, "affine"
}

View File

@@ -4,10 +4,12 @@ package mlxrunner
import (
"bytes"
"context"
"errors"
"log/slog"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -45,8 +47,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
@@ -65,11 +67,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
var b bytes.Buffer
@@ -78,7 +85,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -91,6 +97,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs = append(outputs, output)
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -102,7 +109,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
Token: int(output),
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
@@ -110,10 +117,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
if r.cache != nil {
r.cache.LogCache()
}
}
return nil
}

View File

@@ -58,10 +58,10 @@ type Response struct {
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
}
func (r *Runner) Load(modelName string) error {

View File

@@ -40,8 +40,7 @@ func Execute(args []string) error {
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
Requests: make(chan Request),
}
if err := runner.Load(modelName); err != nil {

View File

@@ -401,9 +401,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -702,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -235,9 +235,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -252,9 +252,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}