Mirror of https://github.com/ollama/ollama.git (synced 2026-02-23 18:46:44 -05:00)

Compare commits

8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 3323c1d319 | |
| | f20dc6b698 | |
| | 4b2ac1f369 | |
| | 8daf47fb3a | |
| | 6c980579cd | |
| | 5c73c4e2ee | |
| | 5daf59cc66 | |
| | 0ade9205cc | |
@@ -253,6 +253,8 @@ func main() {
 		done <- osrv.Run(octx)
 	}()
 
+	upd := &updater.Updater{Store: st}
+
 	uiServer := ui.Server{
 		Token:   token,
 		Restart: func() {
@@ -267,6 +269,10 @@ func main() {
 		ToolRegistry: toolRegistry,
 		Dev:          devMode,
 		Logger:       slog.Default(),
+		Updater:      upd,
+		UpdateAvailableFunc: func() {
+			UpdateAvailable("")
+		},
 	}
 
 	srv := &http.Server{
@@ -284,8 +290,13 @@ func main() {
 		slog.Debug("background desktop server done")
 	}()
 
-	updater := &updater.Updater{Store: st}
-	updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
+	upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
+
+	// Check for pending updates on startup (show tray notification if update is ready)
+	if updater.IsUpdatePending() {
+		slog.Debug("update pending on startup, showing tray notification")
+		UpdateAvailable("")
+	}
 
 	hasCompletedFirstRun, err := st.HasCompletedFirstRun()
 	if err != nil {
@@ -348,6 +359,18 @@ func startHiddenTasks() {
 		// CLI triggered app startup use-case
 		slog.Info("deferring pending update for fast startup")
 	} else {
+		// Check if auto-update is enabled before automatically upgrading
+		st := &store.Store{}
+		settings, err := st.Settings()
+		if err != nil {
+			slog.Warn("failed to load settings for upgrade check", "error", err)
+		} else if !settings.AutoUpdateEnabled {
+			slog.Info("auto-update disabled, skipping automatic upgrade at startup")
+			// Still show tray notification so user knows update is ready
+			UpdateAvailable("")
+			return
+		}
+
 		if err := updater.DoUpgradeAtStartup(); err != nil {
 			slog.Info("unable to perform upgrade at startup", "error", err)
 			// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
@@ -9,12 +9,12 @@ import (
 	"strings"
 	"time"
 
-	sqlite3 "github.com/mattn/go-sqlite3"
+	_ "github.com/mattn/go-sqlite3"
 )
 
 // currentSchemaVersion defines the current database schema version.
 // Increment this when making schema changes that require migrations.
-const currentSchemaVersion = 14
+const currentSchemaVersion = 15
 
 // database wraps the SQLite connection.
 // SQLite handles its own locking for concurrent access:
@@ -86,6 +86,7 @@ func (db *database) init() error {
 		think_level TEXT NOT NULL DEFAULT '',
 		cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
 		remote TEXT NOT NULL DEFAULT '', -- deprecated
+		auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
 		schema_version INTEGER NOT NULL DEFAULT %d
 	);
 
@@ -257,6 +258,12 @@ func (db *database) migrate() error {
 			return fmt.Errorf("migrate v13 to v14: %w", err)
 		}
 		version = 14
+	case 14:
+		// add auto_update_enabled column to settings table
+		if err := db.migrateV14ToV15(); err != nil {
+			return fmt.Errorf("migrate v14 to v15: %w", err)
+		}
+		version = 15
 	default:
 		// If we have a version we don't recognize, just set it to current
 		// This might happen during development
@@ -496,6 +503,21 @@ func (db *database) migrateV13ToV14() error {
 	return nil
 }
 
+// migrateV14ToV15 adds the auto_update_enabled column to the settings table
+func (db *database) migrateV14ToV15() error {
+	_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
+	if err != nil && !duplicateColumnError(err) {
+		return fmt.Errorf("add auto_update_enabled column: %w", err)
+	}
+
+	_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
+	if err != nil {
+		return fmt.Errorf("update schema version: %w", err)
+	}
+
+	return nil
+}
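
Note: the migration above is deliberately idempotent. If the `auto_update_enabled` column already exists (for example on a fresh install whose `CREATE TABLE` already includes it), the `ALTER TABLE` error is swallowed by `duplicateColumnError` and only the schema-version bump runs. A minimal standalone sketch of the same pattern; the connection setup and table here are illustrative, not taken from the PR:

```go
package main

import (
	"database/sql"
	"fmt"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// addColumnIfMissing applies an ALTER TABLE and treats "duplicate column
// name" as success, so the migration can run repeatedly without failing.
func addColumnIfMissing(conn *sql.DB, ddl string) error {
	if _, err := conn.Exec(ddl); err != nil &&
		!strings.Contains(err.Error(), "duplicate column name") {
		return fmt.Errorf("migrate: %w", err)
	}
	return nil
}

func main() {
	conn, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	if _, err := conn.Exec(`CREATE TABLE settings (id INTEGER)`); err != nil {
		panic(err)
	}
	// Running the migration twice is harmless.
	for range 2 {
		if err := addColumnIfMissing(conn,
			`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`); err != nil {
			panic(err)
		}
	}
	fmt.Println("migrated")
}
```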

 // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
 func (db *database) cleanupOrphanedData() error {
 	_, err := db.conn.Exec(`
@@ -526,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
 }
 
 func duplicateColumnError(err error) bool {
-	if sqlite3Err, ok := err.(sqlite3.Error); ok {
-		return sqlite3Err.Code == sqlite3.ErrError &&
-			strings.Contains(sqlite3Err.Error(), "duplicate column name")
-	}
-	return false
+	return err != nil && strings.Contains(err.Error(), "duplicate column name")
 }
 
 func columnNotExists(err error) bool {
-	if sqlite3Err, ok := err.(sqlite3.Error); ok {
-		return sqlite3Err.Code == sqlite3.ErrError &&
-			strings.Contains(sqlite3Err.Error(), "no such column")
-	}
-	return false
+	return err != nil && strings.Contains(err.Error(), "no such column")
 }
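
Note: the old helpers type-asserted the concrete `sqlite3.Error`, which fails as soon as the error is wrapped with `fmt.Errorf("...: %w", err)`; plain substring matching keeps working regardless of wrapping, at the cost of coupling to the driver's message text. A small sketch of the wrapped-error case; the error messages here are illustrative stand-ins, not the driver's actual output:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

func duplicateColumnError(err error) bool {
	return err != nil && strings.Contains(err.Error(), "duplicate column name")
}

func main() {
	// A direct type assertion on a concrete driver error fails once the
	// error is wrapped; substring matching still sees the message.
	base := errors.New("duplicate column name: auto_update_enabled")
	wrapped := fmt.Errorf("migrate v14 to v15: %w", base)

	fmt.Println(duplicateColumnError(base))    // true
	fmt.Println(duplicateColumnError(wrapped)) // true
}
```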

 func (db *database) getAllChats() ([]Chat, error) {
@@ -1152,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
 	var s Settings
 
 	err := db.conn.QueryRow(`
-		SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
+		SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
 		FROM settings
-	`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
+	`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
 	if err != nil {
 		return Settings{}, fmt.Errorf("get settings: %w", err)
 	}
@@ -1164,9 +1178,9 @@ func (db *database) getSettings() (Settings, error) {
 
 func (db *database) setSettings(s Settings) error {
 	_, err := db.conn.Exec(`
-		UPDATE settings
-		SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
-	`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
+		UPDATE settings
+		SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
+	`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
 	if err != nil {
 		return fmt.Errorf("set settings: %w", err)
 	}
@@ -166,6 +166,9 @@ type Settings struct {
 
 	// SidebarOpen indicates if the chat sidebar is open
 	SidebarOpen bool
+
+	// AutoUpdateEnabled indicates if automatic updates should be downloaded
+	AutoUpdateEnabled bool
 }
 
 type Store struct {
@@ -414,6 +414,7 @@ export class Settings {
   ThinkLevel: string;
   SelectedModel: string;
   SidebarOpen: boolean;
+  AutoUpdateEnabled: boolean;
 
   constructor(source: any = {}) {
     if ('string' === typeof source) source = JSON.parse(source);
@@ -431,6 +432,7 @@ export class Settings {
     this.ThinkLevel = source["ThinkLevel"];
     this.SelectedModel = source["SelectedModel"];
     this.SidebarOpen = source["SidebarOpen"];
+    this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
   }
 }
 export class SettingsResponse {
@@ -17,7 +17,10 @@ import {
 } from "@/hooks/useChats";
 import { useNavigate } from "@tanstack/react-router";
 import { useSelectedModel } from "@/hooks/useSelectedModel";
-import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
+import {
+  useHasVisionCapability,
+  useHasToolsCapability,
+} from "@/hooks/useModelCapabilities";
 import { useUser } from "@/hooks/useUser";
 import { DisplayLogin } from "@/components/DisplayLogin";
 import { ErrorEvent, Message } from "@/gotypes";
@@ -149,12 +152,7 @@ function ChatForm({
   } = useSettings();
   const { cloudDisabled } = useCloudStatus();
 
-  // current supported models for web search
-  const modelLower = selectedModel?.model.toLowerCase() || "";
-  const supportsWebSearch =
-    modelLower.startsWith("gpt-oss") ||
-    modelLower.startsWith("qwen3") ||
-    modelLower.startsWith("deepseek-v3");
+  const supportsWebSearch = useHasToolsCapability(selectedModel?.model);
   // Use per-chat thinking level instead of global
   const thinkLevel: ThinkingLevel =
     settingsThinkLevel === "none" || !settingsThinkLevel
|
||||
XMarkIcon,
|
||||
CogIcon,
|
||||
ArrowLeftIcon,
|
||||
ArrowDownTrayIcon,
|
||||
} from "@heroicons/react/20/solid";
|
||||
import { Settings as SettingsType } from "@/gotypes";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
@@ -440,6 +441,29 @@ export default function Settings() {
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Auto Update */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||
<div>
|
||||
<Label>Auto-download updates</Label>
|
||||
<Description>
|
||||
{settings.AutoUpdateEnabled
|
||||
? "Automatically download updates when available."
|
||||
: "Updates will not be downloaded automatically."}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={settings.AutoUpdateEnabled}
|
||||
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Expose Ollama */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
|
||||
@@ -20,3 +20,8 @@ export function useHasVisionCapability(modelName: string | undefined) {
   const { data: capabilitiesResponse } = useModelCapabilities(modelName);
   return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
 }
+
+export function useHasToolsCapability(modelName: string | undefined) {
+  const { data: capabilitiesResponse } = useModelCapabilities(modelName);
+  return capabilitiesResponse?.capabilities?.includes("tools") ?? false;
+}
app/ui/ui.go: 39 changes

@@ -28,6 +28,7 @@ import (
 	"github.com/ollama/ollama/app/tools"
 	"github.com/ollama/ollama/app/types/not"
 	"github.com/ollama/ollama/app/ui/responses"
+	"github.com/ollama/ollama/app/updater"
 	"github.com/ollama/ollama/app/version"
 	ollamaAuth "github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
@@ -106,6 +107,10 @@ type Server struct {
 
 	// Dev is true if the server is running in development mode
 	Dev bool
+
+	// Updater for checking and downloading updates
+	Updater             *updater.Updater
+	UpdateAvailableFunc func()
 }
 
 func (s *Server) log() *slog.Logger {
@@ -829,8 +834,9 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
 
 	if !hasAttachments {
 		WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
+		hasToolsCapability := slices.Contains(details.Capabilities, model.CapabilityTools)
 
-		if WebSearchEnabled {
+		if WebSearchEnabled && hasToolsCapability {
 			if supportsBrowserTools(req.Model) {
 				browserState, ok := s.browserState(chat)
 				if !ok {
@@ -840,7 +846,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
 				registry.Register(tools.NewBrowserSearch(browser))
 				registry.Register(tools.NewBrowserOpen(browser))
 				registry.Register(tools.NewBrowserFind(browser))
-			} else if supportsWebSearchTools(req.Model) {
+			} else {
 				registry.Register(&tools.WebSearch{})
 				registry.Register(&tools.WebFetch{})
 			}
@@ -1446,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
 		return fmt.Errorf("failed to save settings: %w", err)
 	}
 
+	// Handle auto-update toggle changes
+	if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
+		if !settings.AutoUpdateEnabled {
+			// Auto-update disabled: cancel any ongoing download
+			if s.Updater != nil {
+				s.Updater.CancelOngoingDownload()
+			}
+		} else {
+			// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
+			if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
+				s.UpdateAvailableFunc()
+			} else if s.Updater != nil {
+				// Trigger the background checker to run immediately
+				s.Updater.TriggerImmediateCheck()
+			}
+		}
+	}
+
 	if old.ContextLength != settings.ContextLength ||
 		old.Models != settings.Models ||
 		old.Expose != settings.Expose {
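
Note: the handler only reacts to transitions. Side effects run when the stored value and the incoming value differ, not on every save, so saving unrelated settings cannot re-trigger update logic. A minimal sketch of that edge-triggered pattern; the stub type and function names here are simplified stand-ins for the PR's `Updater`, not its real API:

```go
package main

import "fmt"

type updaterStub struct{}

func (updaterStub) CancelOngoingDownload() { fmt.Println("download cancelled") }
func (updaterStub) TriggerImmediateCheck() { fmt.Println("immediate check") }

// applyAutoUpdateToggle runs side effects only when the setting changes,
// mirroring the old-vs-new comparison in the settings handler.
func applyAutoUpdateToggle(oldEnabled, newEnabled, updatePending bool, u updaterStub, notify func()) {
	if oldEnabled == newEnabled {
		return // no transition, nothing to do
	}
	if !newEnabled {
		u.CancelOngoingDownload()
		return
	}
	if updatePending {
		notify() // an update is already staged; surface it right away
		return
	}
	u.TriggerImmediateCheck()
}

func main() {
	applyAutoUpdateToggle(true, false, false, updaterStub{}, nil)                              // off: cancel
	applyAutoUpdateToggle(false, true, true, updaterStub{}, func() { fmt.Println("notify") }) // on + staged: notify
	applyAutoUpdateToggle(false, true, false, updaterStub{}, nil)                              // on: check now
}
```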
@@ -1648,17 +1672,6 @@ func supportsBrowserTools(model string) bool {
 	return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
 }
 
-// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
-func supportsWebSearchTools(model string) bool {
-	model = strings.ToLower(model)
-	prefixes := []string{"qwen3", "deepseek-v3"}
-	for _, p := range prefixes {
-		if strings.HasPrefix(model, p) {
-			return true
-		}
-	}
-	return false
-}
 
 // buildChatRequest converts store.Chat to api.ChatRequest
 func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
@@ -4,6 +4,7 @@ package ui
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -11,9 +12,11 @@ import (
 	"path/filepath"
 	"runtime"
 	"strings"
+	"sync/atomic"
 	"testing"
 
 	"github.com/ollama/ollama/app/store"
+	"github.com/ollama/ollama/app/updater"
 )
 
 func TestHandlePostApiSettings(t *testing.T) {
@@ -522,3 +525,290 @@ func TestUserAgentTransport(t *testing.T) {
 
 	t.Logf("User-Agent transport successfully set: %s", receivedUA)
 }
func TestSupportsBrowserTools(t *testing.T) {
	tests := []struct {
		model string
		want  bool
	}{
		{"gpt-oss", true},
		{"gpt-oss-latest", true},
		{"GPT-OSS", true},
		{"Gpt-Oss-v2", true},
		{"qwen3", false},
		{"deepseek-v3", false},
		{"llama3.3", false},
		{"", false},
	}

	for _, tt := range tests {
		t.Run(tt.model, func(t *testing.T) {
			if got := supportsBrowserTools(tt.model); got != tt.want {
				t.Errorf("supportsBrowserTools(%q) = %v, want %v", tt.model, got, tt.want)
			}
		})
	}
}

func TestWebSearchToolRegistration(t *testing.T) {
	// Validates that the capability-gating logic in chat() correctly
	// decides which tools to register based on model capabilities and
	// the web search flag.
	tests := []struct {
		name             string
		webSearchEnabled bool
		hasToolsCap      bool
		model            string
		wantBrowser      bool // expects browser tools (gpt-oss)
		wantWebSearch    bool // expects basic web search/fetch tools
		wantNone         bool // expects no tools registered
	}{
		{
			name:             "web search enabled with tools capability - browser model",
			webSearchEnabled: true,
			hasToolsCap:      true,
			model:            "gpt-oss-latest",
			wantBrowser:      true,
		},
		{
			name:             "web search enabled with tools capability - non-browser model",
			webSearchEnabled: true,
			hasToolsCap:      true,
			model:            "qwen3",
			wantWebSearch:    true,
		},
		{
			name:             "web search enabled without tools capability",
			webSearchEnabled: true,
			hasToolsCap:      false,
			model:            "llama3.3",
			wantNone:         true,
		},
		{
			name:             "web search disabled with tools capability",
			webSearchEnabled: false,
			hasToolsCap:      true,
			model:            "qwen3",
			wantNone:         true,
		},
		{
			name:             "web search disabled without tools capability",
			webSearchEnabled: false,
			hasToolsCap:      false,
			model:            "llama3.3",
			wantNone:         true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Replicate the decision logic from chat() handler
			gotBrowser := false
			gotWebSearch := false

			if tt.webSearchEnabled && tt.hasToolsCap {
				if supportsBrowserTools(tt.model) {
					gotBrowser = true
				} else {
					gotWebSearch = true
				}
			}

			if tt.wantBrowser && !gotBrowser {
				t.Error("expected browser tools to be registered")
			}
			if tt.wantWebSearch && !gotWebSearch {
				t.Error("expected web search tools to be registered")
			}
			if tt.wantNone && (gotBrowser || gotWebSearch) {
				t.Error("expected no tools to be registered")
			}
			if !tt.wantBrowser && gotBrowser {
				t.Error("unexpected browser tools registered")
			}
			if !tt.wantWebSearch && gotWebSearch {
				t.Error("unexpected web search tools registered")
			}
		})
	}
}
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
	testStore := &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
	}
	defer testStore.Close()

	// Start with auto-update enabled
	settings, err := testStore.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = true
	if err := testStore.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	upd := &updater.Updater{Store: &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
	}}
	defer upd.Store.Close()

	// We can't easily mock CancelOngoingDownload, but we can verify
	// the full settings handler flow works without error
	server := &Server{
		Store:   testStore,
		Restart: func() {},
		Updater: upd,
	}

	// Disable auto-update via settings API
	settings.AutoUpdateEnabled = false
	body, err := json.Marshal(settings)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	if err := server.settings(rr, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rr.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
	}

	// Verify settings were saved with auto-update disabled
	saved, err := testStore.Settings()
	if err != nil {
		t.Fatal(err)
	}
	if saved.AutoUpdateEnabled {
		t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
	}
}
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
	testStore := &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
	}
	defer testStore.Close()

	// Start with auto-update disabled
	settings, err := testStore.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = false
	if err := testStore.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	// Simulate that an update was previously downloaded
	oldVal := updater.UpdateDownloaded
	updater.UpdateDownloaded = true
	defer func() { updater.UpdateDownloaded = oldVal }()

	var notificationCalled atomic.Bool
	server := &Server{
		Store:   testStore,
		Restart: func() {},
		UpdateAvailableFunc: func() {
			notificationCalled.Store(true)
		},
	}

	// Re-enable auto-update via settings API
	settings.AutoUpdateEnabled = true
	body, err := json.Marshal(settings)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	if err := server.settings(rr, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rr.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
	}

	if !notificationCalled.Load() {
		t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
	}
}
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
	testStore := &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
	}
	defer testStore.Close()

	// Start with auto-update disabled
	settings, err := testStore.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = false
	if err := testStore.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	// Ensure no pending update - clear both the downloaded flag and the stage dir
	oldVal := updater.UpdateDownloaded
	updater.UpdateDownloaded = false
	defer func() { updater.UpdateDownloaded = oldVal }()

	oldStageDir := updater.UpdateStageDir
	updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
	defer func() { updater.UpdateStageDir = oldStageDir }()

	upd := &updater.Updater{Store: &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
	}}
	defer upd.Store.Close()

	// Initialize the checkNow channel by starting (and immediately stopping) the checker
	// so TriggerImmediateCheck doesn't panic on nil channel
	ctx, cancel := context.WithCancel(t.Context())
	upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
	defer cancel()

	var notificationCalled atomic.Bool
	server := &Server{
		Store:   testStore,
		Restart: func() {},
		Updater: upd,
		UpdateAvailableFunc: func() {
			notificationCalled.Store(true)
		},
	}

	// Re-enable auto-update via settings API
	settings.AutoUpdateEnabled = true
	body, err := json.Marshal(settings)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()

	if err := server.settings(rr, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rr.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
	}

	// UpdateAvailableFunc should NOT be called since there's no pending update
	if notificationCalled.Load() {
		t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
	}
}
@@ -19,6 +19,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
 	query := requestURL.Query()
 	query.Add("os", runtime.GOOS)
 	query.Add("arch", runtime.GOARCH)
-	query.Add("version", version.Version)
+	currentVersion := version.Version
+	query.Add("version", currentVersion)
 	query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
 
 	// The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
 }
 
 func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
+	// Create a cancellable context for this download
+	downloadCtx, cancel := context.WithCancel(ctx)
+	u.cancelDownloadLock.Lock()
+	u.cancelDownload = cancel
+	u.cancelDownloadLock.Unlock()
+	defer func() {
+		u.cancelDownloadLock.Lock()
+		u.cancelDownload = nil
+		u.cancelDownloadLock.Unlock()
+		cancel()
+	}()
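
Note: the pattern above stashes the download's `context.CancelFunc` behind a mutex so another goroutine (the settings handler) can abort it, and the deferred reset guarantees a stale cancel func never outlives its download. A self-contained sketch of the same idea; the `worker` type and names are generic illustrations, not the PR's API:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// worker owns at most one cancellable operation at a time; Cancel may be
// called safely from any goroutine.
type worker struct {
	mu     sync.Mutex
	cancel context.CancelFunc
}

func (w *worker) Run(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	w.mu.Lock()
	w.cancel = cancel
	w.mu.Unlock()
	defer func() {
		w.mu.Lock()
		w.cancel = nil // never leave a stale cancel behind
		w.mu.Unlock()
		cancel()
	}()

	select {
	case <-time.After(time.Second): // stands in for the long download
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func (w *worker) Cancel() {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.cancel != nil {
		w.cancel()
		w.cancel = nil
	}
}

func main() {
	var w worker
	go func() { time.Sleep(10 * time.Millisecond); w.Cancel() }()
	fmt.Println(w.Run(context.Background())) // prints "context canceled"
}
```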

 	// Do a head first to check etag info
-	req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
+	req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
 	if err != nil {
 		return err
 	}
 
 	// In case of slow downloads, continue the update check in the background
-	bgctx, cancel := context.WithCancel(ctx)
-	defer cancel()
+	bgctx, bgcancel := context.WithCancel(downloadCtx)
+	defer bgcancel()
 	go func() {
 		for {
 			select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
 	_, err = os.Stat(stageFilename)
 	if err == nil {
 		slog.Info("update already downloaded", "bundle", stageFilename)
+		UpdateDownloaded = true
 		return nil
 	}
@@ -244,33 +259,84 @@ func cleanupOldDownloads(stageDir string) {
 }
 
 type Updater struct {
-	Store *store.Store
+	Store              *store.Store
+	cancelDownload     context.CancelFunc
+	cancelDownloadLock sync.Mutex
+	checkNow           chan struct{}
 }
 
+// CancelOngoingDownload cancels any currently running download
+func (u *Updater) CancelOngoingDownload() {
+	u.cancelDownloadLock.Lock()
+	defer u.cancelDownloadLock.Unlock()
+	if u.cancelDownload != nil {
+		slog.Info("cancelling ongoing update download")
+		u.cancelDownload()
+		u.cancelDownload = nil
+	}
+}
+
+// TriggerImmediateCheck signals the background checker to check for updates immediately
+func (u *Updater) TriggerImmediateCheck() {
+	if u.checkNow != nil {
+		select {
+		case u.checkNow <- struct{}{}:
+		default:
+			// Check already pending, no need to queue another
+		}
+	}
+}
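
Note: `TriggerImmediateCheck` relies on a one-slot buffered channel plus a `select` with `default`, which makes the send non-blocking and coalesces bursts of triggers into a single pending check. A tiny runnable illustration of that channel idiom (names here are illustrative):

```go
package main

import "fmt"

func main() {
	// A 1-slot buffered channel coalesces triggers: if a check is already
	// queued, further triggers become no-ops instead of blocking the caller.
	checkNow := make(chan struct{}, 1)

	trigger := func() bool {
		select {
		case checkNow <- struct{}{}:
			return true // queued
		default:
			return false // one is already pending
		}
	}

	fmt.Println(trigger()) // true
	fmt.Println(trigger()) // false: coalesced with the pending trigger
	<-checkNow             // the worker consumes one signal
	fmt.Println(trigger()) // true again
}
```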

 func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
+	u.checkNow = make(chan struct{}, 1)
 	go func() {
 		// Don't blast an update message immediately after startup
 		time.Sleep(UpdateCheckInitialDelay)
 		slog.Info("beginning update checker", "interval", UpdateCheckInterval)
+		ticker := time.NewTicker(UpdateCheckInterval)
+		defer ticker.Stop()
 
 		for {
-			available, resp := u.checkForUpdate(ctx)
-			if available {
-				err := u.DownloadNewRelease(ctx, resp)
-				if err != nil {
-					slog.Error(fmt.Sprintf("failed to download new release: %s", err))
-				} else {
-					err = cb(resp.UpdateVersion)
-					if err != nil {
-						slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
-					}
-				}
-			}
 			select {
 			case <-ctx.Done():
 				slog.Debug("stopping background update checker")
 				return
-			default:
-				time.Sleep(UpdateCheckInterval)
+			case <-u.checkNow:
+				// Immediate check triggered
+			case <-ticker.C:
+				// Regular interval check
 			}
+
+			// Always check for updates
+			available, resp := u.checkForUpdate(ctx)
+			if !available {
+				continue
+			}
+
+			// Update is available - check if auto-update is enabled for downloading
+			settings, err := u.Store.Settings()
+			if err != nil {
+				slog.Error("failed to load settings", "error", err)
+				continue
+			}
+
+			if !settings.AutoUpdateEnabled {
+				// Auto-update disabled - don't download, just log
+				slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
+				continue
+			}
+
+			// Auto-update is enabled - download
+			err = u.DownloadNewRelease(ctx, resp)
+			if err != nil {
+				slog.Error("failed to download new release", "error", err)
+				continue
+			}
+
+			// Download successful - show tray notification (regardless of toggle state)
+			err = cb(resp.UpdateVersion)
+			if err != nil {
+				slog.Warn("failed to register update available with tray", "error", err)
+			}
 		}
 	}()
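
Note: replacing `time.Sleep(UpdateCheckInterval)` with a ticker inside the `select` is what makes the wait interruptible: the loop can now react to cancellation and to `TriggerImmediateCheck` while idle, instead of only between sleeps. A compact sketch of that loop shape; names are illustrative, not from the PR:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// pollLoop waits on three events at once: cancellation, an on-demand
// trigger, and a periodic tick. A sleeping loop could do none of this
// until its sleep elapsed.
func pollLoop(ctx context.Context, trigger <-chan struct{}, interval time.Duration, check func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-trigger: // immediate check, no waiting out the interval
		case <-ticker.C: // regular periodic check
		}
		check()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	trigger := make(chan struct{}, 1)
	trigger <- struct{}{}
	pollLoop(ctx, trigger, time.Hour, func() { fmt.Println("checked") })
}
```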
@@ -11,6 +11,8 @@ import (
 	"log/slog"
 	"net/http"
 	"net/http/httptest"
+	"path/filepath"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
 	defer server.Close()
 	slog.Debug("server", "url", server.URL)
 
-	updater := &Updater{Store: &store.Store{}}
+	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
 	defer updater.Store.Close() // Ensure database is closed
 	UpdateCheckURLBase = server.URL + "/update.json"
 	updatePresent, resp := updater.checkForUpdate(t.Context())
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
 	defer server.Close()
 	UpdateCheckURLBase = server.URL + "/update.json"
 
-	updater := &Updater{Store: &store.Store{}}
-	defer updater.Store.Close() // Ensure database is closed
+	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
+	defer updater.Store.Close()
+
+	settings, err := updater.Store.Settings()
+	if err != nil {
+		t.Fatal(err)
+	}
+	settings.AutoUpdateEnabled = true
+	if err := updater.Store.SetSettings(settings); err != nil {
+		t.Fatal(err)
+	}
 
 	updater.StartBackgroundUpdaterChecker(ctx, cb)
 	select {
 	case <-stallTimer.C:
@@ -99,3 +111,264 @@ func TestBackgoundChecker(t *testing.T) {
 		}
 	}
 }
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
	UpdateStageDir = t.TempDir()
	var downloadAttempted atomic.Bool
	done := make(chan struct{})

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	UpdateCheckInitialDelay = 5 * time.Millisecond
	UpdateCheckInterval = 5 * time.Millisecond
	VerifyDownload = func() error {
		return nil
	}

	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			w.Write([]byte(
				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
					server.URL+"/9.9.9/"+Installer)))
		} else if r.URL.Path == "/9.9.9/"+Installer {
			downloadAttempted.Store(true)
			buf := &bytes.Buffer{}
			zw := zip.NewWriter(buf)
			zw.Close()
			io.Copy(w, buf)
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"

	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()

	// Ensure auto-update is disabled
	settings, err := updater.Store.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = false
	if err := updater.Store.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	cb := func(ver string) error {
		t.Fatal("callback should not be called when auto-update is disabled")
		return nil
	}

	updater.StartBackgroundUpdaterChecker(ctx, cb)

	// Wait enough time for multiple check cycles
	time.Sleep(50 * time.Millisecond)
	close(done)

	if downloadAttempted.Load() {
		t.Fatal("download should not be attempted when auto-update is disabled")
	}
}

func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
	UpdateStageDir = t.TempDir()
	var downloadAttempted atomic.Bool
	callbackCalled := make(chan struct{}, 1)

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	UpdateCheckInitialDelay = 5 * time.Millisecond
	UpdateCheckInterval = 5 * time.Millisecond
	VerifyDownload = func() error {
		return nil
	}

	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			w.Write([]byte(
				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
					server.URL+"/9.9.9/"+Installer)))
		} else if r.URL.Path == "/9.9.9/"+Installer {
			downloadAttempted.Store(true)
			buf := &bytes.Buffer{}
			zw := zip.NewWriter(buf)
			zw.Close()
			io.Copy(w, buf)
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"

	upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer upd.Store.Close()

	// Start with auto-update disabled
	settings, err := upd.Store.Settings()
	if err != nil {
		t.Fatal(err)
	}
	settings.AutoUpdateEnabled = false
	if err := upd.Store.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	cb := func(ver string) error {
		select {
		case callbackCalled <- struct{}{}:
		default:
		}
		return nil
	}

	upd.StartBackgroundUpdaterChecker(ctx, cb)

	// Wait for a few cycles with auto-update disabled - no download should happen
	time.Sleep(50 * time.Millisecond)
	if downloadAttempted.Load() {
		t.Fatal("download should not happen while auto-update is disabled")
	}

	// Re-enable auto-update
	settings.AutoUpdateEnabled = true
	if err := upd.Store.SetSettings(settings); err != nil {
		t.Fatal(err)
	}

	// Wait for the checker to pick it up and download
	select {
	case <-callbackCalled:
		// Success: download happened and callback was called after re-enabling
		if !downloadAttempted.Load() {
			t.Fatal("expected download to be attempted after re-enabling")
		}
	case <-time.After(5 * time.Second):
		t.Fatal("expected download and callback after re-enabling auto-update")
	}
}
func TestCancelOngoingDownload(t *testing.T) {
	UpdateStageDir = t.TempDir()
	downloadStarted := make(chan struct{})
	downloadCancelled := make(chan struct{})

	ctx := t.Context()
	VerifyDownload = func() error {
		return nil
	}

	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			w.Write([]byte(
				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
					server.URL+"/9.9.9/"+Installer)))
		} else if r.URL.Path == "/9.9.9/"+Installer {
			if r.Method == http.MethodHead {
				w.Header().Set("Content-Length", "1000000")
				w.WriteHeader(http.StatusOK)
				return
			}
			// Signal that download has started
			close(downloadStarted)
			// Wait for cancellation or timeout
			select {
			case <-r.Context().Done():
				close(downloadCancelled)
				return
			case <-time.After(5 * time.Second):
				t.Error("download was not cancelled in time")
			}
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"

	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()

	_, resp := updater.checkForUpdate(ctx)

	// Start download in goroutine
	go func() {
		_ = updater.DownloadNewRelease(ctx, resp)
	}()

	// Wait for download to start
	select {
	case <-downloadStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("download did not start in time")
	}

	// Cancel the download
	updater.CancelOngoingDownload()

	// Verify cancellation was received
	select {
	case <-downloadCancelled:
		// Success
	case <-time.After(2 * time.Second):
		t.Fatal("download cancellation was not received by server")
	}
}
func TestTriggerImmediateCheck(t *testing.T) {
	UpdateStageDir = t.TempDir()
	checkCount := atomic.Int32{}
	checkDone := make(chan struct{}, 10)

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	// Set a very long interval so only TriggerImmediateCheck causes checks
	UpdateCheckInitialDelay = 1 * time.Millisecond
	UpdateCheckInterval = 1 * time.Hour
	VerifyDownload = func() error {
		return nil
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/update.json" {
			checkCount.Add(1)
			select {
			case checkDone <- struct{}{}:
			default:
			}
			// Return no update available
			w.WriteHeader(http.StatusNoContent)
		}
	}))
	defer server.Close()
	UpdateCheckURLBase = server.URL + "/update.json"

	updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
	defer updater.Store.Close()

	cb := func(ver string) error {
		return nil
	}

	updater.StartBackgroundUpdaterChecker(ctx, cb)

	// Wait for goroutine to start and pass initial delay
	time.Sleep(10 * time.Millisecond)

	// With 1 hour interval, no check should have happened yet
	initialCount := checkCount.Load()

	// Trigger immediate check
	updater.TriggerImmediateCheck()

	// Wait for the triggered check
	select {
	case <-checkDone:
	case <-time.After(2 * time.Second):
		t.Fatal("triggered check did not happen")
	}

	finalCount := checkCount.Load()
	if finalCount <= initialCount {
		t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
	}
}
@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
 	return nil
 }
 
-// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
-// 	const ERROR_SUCCESS syscall.Errno = 0
-
-// 	t.muMenus.RLock()
-// 	menu := uintptr(t.menus[parentId])
-// 	t.muMenus.RUnlock()
-// 	res, _, err := pRemoveMenu.Call(
-// 		menu,
-// 		uintptr(menuItemId),
-// 		MF_BYCOMMAND,
-// 	)
-// 	if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
-// 		return err
-// 	}
-// 	t.delFromVisibleItems(parentId, menuItemId)
-
-// 	return nil
-// }
 
 func (t *winTray) showMenu() error {
 	p := point{}
 	boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
@@ -51,7 +51,6 @@ const (
 	IMAGE_ICON      = 1          // Loads an icon
 	LR_DEFAULTSIZE  = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
 	LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
-	MF_BYCOMMAND    = 0x00000000
 	MFS_DISABLED    = 0x00000003
 	MFT_SEPARATOR   = 0x00000800
 	MFT_STRING      = 0x00000000
@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
 	if err != nil {
 		return nil, nil, err
 	}
+	bts = sanitizeNonFiniteJSON(bts)
 
 	var p ModelParameters
 	if err := json.Unmarshal(bts, &p); err != nil {
-		return nil, nil, err
+		return nil, nil, fmt.Errorf("parse config.json: %w", err)
 	}
 
 	if len(p.Architectures) < 1 {
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
 		conv = &glm4MoeLiteModel{}
 	case "GlmOcrForConditionalGeneration":
 		conv = &glmOcrModel{}
-	case "Lfm2ForCausalLM":
+	case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
 		conv = &lfm2Model{}
 	case "Lfm2VlForConditionalGeneration":
 		conv = &lfm2VLTextModel{}
 	case "Qwen3NextForCausalLM":
 		conv = &qwen3NextModel{}
 	case "NemotronHForCausalLM":
 		conv = &nemotronHModel{}
 	default:
 		return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
 
 	if err := json.Unmarshal(bts, conv); err != nil {
-		return nil, nil, err
+		return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
 	}
 
 	if t, ok := conv.(moreParser); ok {
@@ -1,6 +1,8 @@
 package convert
 
 import (
+	"cmp"
+	"fmt"
 	"slices"
 	"strings"
 
@@ -13,42 +15,149 @@ type lfm2Model struct {
 	NumHiddenLayers       uint32   `json:"num_hidden_layers"`
 	MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
 	IntermediateSize      uint32   `json:"intermediate_size"`
+	BlockFFDim            uint32   `json:"block_ff_dim"`
+	BlockMultipleOf       uint32   `json:"block_multiple_of"`
+	BlockAutoAdjustFFDim  bool     `json:"block_auto_adjust_ff_dim"`
+	BlockFFNDimMultiplier float32  `json:"block_ffn_dim_multiplier"`
 	NumAttentionHeads     uint32   `json:"num_attention_heads"`
 	NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
 	RopeTheta             float32  `json:"rope_theta"`
 	NormEps               float32  `json:"norm_eps"`
 	ConvLCache            uint32   `json:"conv_L_cache"`
+	MoEIntermediateSize   uint32   `json:"moe_intermediate_size"`
+	NumExperts            uint32   `json:"num_experts"`
+	NumLocalExperts       uint32   `json:"num_local_experts"`
+	NumExpertsPerToken    uint32   `json:"num_experts_per_tok"`
+	NumDenseLayers        uint32   `json:"num_dense_layers"`
+	RoutedScalingFactor   float32  `json:"routed_scaling_factor"`
 	LayerTypes            []string `json:"layer_types"`
 	TieEmbedding          bool     `json:"tie_embedding"`
+	RopeParameters        struct {
+		RopeTheta float32 `json:"rope_theta"`
+	} `json:"rope_parameters"`
 }
 
 var _ ModelConverter = (*lfm2Model)(nil)
 
+const (
+	defaultMaxPositionEmbeddings = uint32(128_000)
+	fallbackContextLength        = uint32(32_768)
+)
+
+func (p *lfm2Model) isMoE() bool {
+	return p.ModelType == "lfm2_moe" || p.expertCount() > 0
+}
+
+func (p *lfm2Model) ropeFreqBase() float32 {
+	if p.RopeTheta != 0 {
+		return p.RopeTheta
+	}
+
+	return p.RopeParameters.RopeTheta
+}
+
+func (p *lfm2Model) expertCount() uint32 {
+	if p.NumLocalExperts > 0 {
+		return p.NumLocalExperts
+	}
+	return p.NumExperts
+}
+
+func (p *lfm2Model) feedForwardLength() uint32 {
+	ff := p.IntermediateSize
+	if p.BlockFFDim != 0 {
+		ff = p.BlockFFDim
+	}
+
+	if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
+		return ff
+	}
+
+	ff = (2 * ff) / 3
+
+	// Keep default multiplier behavior consistent with llama.cpp conversion.
+	if p.BlockFFNDimMultiplier != 0 {
+		ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
+	}
+
+	m := p.BlockMultipleOf
+	return m * ((ff + m - 1) / m)
+}
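
Note: `feedForwardLength` implements the auto-adjust rule: take `block_ff_dim` (falling back to `intermediate_size`), scale by 2/3, apply the optional multiplier, then round up to a multiple of `block_multiple_of`. A worked example with hypothetical inputs, chosen only to make the arithmetic easy to follow:

```go
package main

import "fmt"

// feedForwardLength mirrors the conversion rule above; the inputs in main
// are illustrative, not taken from any real LFM2 config.
func feedForwardLength(ff, multipleOf uint32, autoAdjust bool, multiplier float32) uint32 {
	if !autoAdjust || multipleOf == 0 {
		return ff
	}
	ff = (2 * ff) / 3
	if multiplier != 0 {
		ff = uint32(float32(ff) * multiplier)
	}
	return multipleOf * ((ff + multipleOf - 1) / multipleOf)
}

func main() {
	// 2*8192/3 = 5461, rounded up to a multiple of 256 -> 5632.
	fmt.Println(feedForwardLength(8192, 256, true, 0))
	// Auto-adjust off: the value passes through unchanged.
	fmt.Println(feedForwardLength(8192, 256, false, 0))
}
```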

+func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
+	return p.isMoE() &&
+		p.VocabSize == 65536 &&
+		p.HiddenSize == 2048 &&
+		p.NumHiddenLayers == 40 &&
+		p.IntermediateSize == 11776 &&
+		p.NumAttentionHeads == 32 &&
+		p.NumKeyValueHeads == 8 &&
+		p.NumDenseLayers == 2 &&
+		p.expertCount() == 64 &&
+		p.NumExpertsPerToken == 4 &&
+		p.MoEIntermediateSize == 1536
+}
+
+func (p *lfm2Model) contextLength() uint32 {
+	if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
+		return fallbackContextLength
+	}
+
+	return p.MaxPositionEmbeddings
+}
+
 func (p *lfm2Model) KV(t *Tokenizer) KV {
+	architecture := "lfm2"
+	if p.isMoE() {
+		architecture = "lfm2moe"
+	}
+
 	kv := p.ModelParameters.KV(t)
-	kv["general.architecture"] = "lfm2"
-	kv["lfm2.vocab_size"] = p.VocabSize
-	kv["lfm2.block_count"] = p.NumHiddenLayers
-	kv["lfm2.embedding_length"] = p.HiddenSize
-	kv["lfm2.feed_forward_length"] = p.IntermediateSize
-	kv["lfm2.context_length"] = p.MaxPositionEmbeddings
+	kv["general.architecture"] = architecture
+	kv["tokenizer.ggml.pre"] = "lfm2"
+	kv["vocab_size"] = p.VocabSize
+	kv["block_count"] = p.NumHiddenLayers
+	kv["embedding_length"] = p.HiddenSize
+	kv["feed_forward_length"] = p.feedForwardLength()
+	kv["context_length"] = p.contextLength()
 
 	// Build per-layer KV head count array based on layer_types
-	// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
+	// (0 = shortconv layer, non-zero = attention layer with that many KV heads).
+	//
+	// Dense LFM2 in HF defaults to all attention layers when layer_types is absent.
+	// Preserve that behavior to avoid accidentally emitting all-conv metadata.
 	kvHeadCounts := make([]uint32, p.NumHiddenLayers)
-	for i := range p.NumHiddenLayers {
-		if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
+	if len(p.LayerTypes) == 0 {
+		for i := range p.NumHiddenLayers {
 			kvHeadCounts[i] = p.NumKeyValueHeads
 		}
+	} else {
+		for i := range p.NumHiddenLayers {
+			if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
+				kvHeadCounts[i] = p.NumKeyValueHeads
+			}
+		}
+	}
 
-	kv["lfm2.attention.head_count"] = p.NumAttentionHeads
-	kv["lfm2.attention.head_count_kv"] = kvHeadCounts
-	kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
-	kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
-	kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
-	kv["lfm2.rope.freq_base"] = p.RopeTheta
-	kv["lfm2.shortconv.l_cache"] = p.ConvLCache
+	kv["attention.head_count"] = p.NumAttentionHeads
+	kv["attention.head_count_kv"] = kvHeadCounts
+	kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
+	kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
+	kv["attention.layer_norm_rms_epsilon"] = p.NormEps
+	kv["shortconv.l_cache"] = p.ConvLCache
+
+	if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
+		kv["rope.freq_base"] = ropeFreqBase
+	}
+
+	if p.isMoE() {
+		kv["expert_count"] = p.expertCount()
+		kv["expert_used_count"] = p.NumExpertsPerToken
+		kv["expert_feed_forward_length"] = p.MoEIntermediateSize
+		kv["leading_dense_block_count"] = p.NumDenseLayers
+		kv["expert_gating_func"] = uint32(2) // sigmoid
+		kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0))
+	}
 
 	return kv
 }
@@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
 func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
 	var out []*ggml.Tensor
 
+	if p.isMoE() {
+		merges := make([]merge, 0, p.NumHiddenLayers*3)
+		for i := range p.NumHiddenLayers {
+			if i < p.NumDenseLayers {
+				continue
+			}
+
+			merges = append(merges, merge{
+				fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
+				fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+			}, merge{
+				fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
+				fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+			}, merge{
+				fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
+				fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+			})
+		}
+
+		merged, remaining := mergeTensors(ts, merges...)
+		out = append(out, merged...)
+		ts = remaining
+	}
+
 	for _, t := range ts {
 		shape := t.Shape()
@@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
 func (p *lfm2Model) Replacements() []string {
 	return []string{
 		"model.embed_tokens", "token_embd",
-		"model.embedding_norm", "output_norm",
+		"model.embedding_norm", "token_embd_norm",
 		"model.layers", "blk",
 		"operator_norm", "attn_norm",
 		"self_attn.q_proj", "attn_q",
@@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string {
 		"conv.conv", "shortconv.conv",
 		"conv.in_proj", "shortconv.in_proj",
 		"conv.out_proj", "shortconv.out_proj",
+		"feed_forward.gate", "ffn_gate_inp",
+		"feed_forward.expert_bias", "exp_probs_b.bias",
 		"feed_forward.w1", "ffn_gate",
 		"feed_forward.w2", "ffn_down",
 		"feed_forward.w3", "ffn_up",
convert/convert_lfm2_test.go: 271 changes (new file)

@@ -0,0 +1,271 @@
package convert

import (
	"io"
	"slices"
	"strings"
	"testing"
)

type lfm2StubTensor struct {
	tensorBase
}

func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
	return &lfm2StubTensor{
		tensorBase: tensorBase{
			name:  name,
			shape: shape,
		},
	}
}

func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
	return 0, nil
}

func (t *lfm2StubTensor) Clone() Tensor {
	return &lfm2StubTensor{
		tensorBase: tensorBase{
			name:  t.name,
			shape: slices.Clone(t.shape),
		},
	}
}

func TestLFM2MoEKV(t *testing.T) {
	var p lfm2Model
	p.ModelParameters.ModelType = "lfm2_moe"
	p.VocabSize = 65536
	p.HiddenSize = 2048
	p.NumHiddenLayers = 4
	p.MaxPositionEmbeddings = 128000
	p.IntermediateSize = 11776
	p.NumAttentionHeads = 32
	p.NumKeyValueHeads = 8
	p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
	p.NormEps = 1e-5
	p.ConvLCache = 3
	p.MoEIntermediateSize = 1536
	p.NumExperts = 64
	p.NumExpertsPerToken = 4
	p.NumDenseLayers = 2
	p.RopeParameters.RopeTheta = 1_000_000

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["general.architecture"], "lfm2moe"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}

	if got, want := kv["expert_count"], uint32(64); got != want {
		t.Fatalf("expert_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_used_count"], uint32(4); got != want {
		t.Fatalf("expert_used_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
		t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
	}

	if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
		t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_gating_func"], uint32(2); got != want {
		t.Fatalf("expert_gating_func = %v, want %v", got, want)
	}

	gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
	}

	wantHeadCounts := []uint32{0, 8, 0, 8}
	if !slices.Equal(gotHeadCounts, wantHeadCounts) {
		t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
	}

	if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
		t.Fatalf("rope.freq_base = %v, want %v", got, want)
	}
}
func TestLFM2DenseKV(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 32000},
		HiddenSize:            1024,
		NumHiddenLayers:       2,
		MaxPositionEmbeddings: 32768,
		IntermediateSize:      4096,
		NumAttentionHeads:     16,
		NumKeyValueHeads:      4,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		RopeTheta:             10000,
	}

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}

	if _, ok := kv["expert_count"]; ok {
		t.Fatalf("expert_count should not be set for dense lfm2")
	}
}

func TestLFM2MoETensors(t *testing.T) {
	p := lfm2Model{
		ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
		NumHiddenLayers: 4,
		NumDenseLayers:  2,
	}

	in := []Tensor{
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
	}

	out := p.Tensors(in)

	byName := make(map[string][]uint64, len(out))
	for _, tns := range out {
		byName[tns.Name] = tns.Shape
	}

	if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
	}

	if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
		t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
	}

	if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
	}

	if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
		t.Fatalf("missing shortconv tensor")
	} else if !slices.Equal(got, []uint64{2048, 3}) {
		t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
	}

	if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
		t.Fatalf("unmerged expert tensor should not be present")
	}
}

func TestLFM2MoEReplacements(t *testing.T) {
	p := lfm2Model{}
	replacer := strings.NewReplacer(p.Replacements()...)

	if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want {
		t.Fatalf("expert bias replacement = %q, want %q", got, want)
	}

	if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want {
		t.Fatalf("gate replacement = %q, want %q", got, want)
	}
}

func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       40,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            make([]string, 40),
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}
	for i := 0; i < len(p.LayerTypes); i++ {
		p.LayerTypes[i] = "conv"
	}
	p.LayerTypes[2] = "full_attention"

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["context_length"], uint32(32768); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}

func TestLFM2KVContextLengthNoOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       39, // mismatch: should not trigger edge case
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["context_length"], uint32(128000); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}

func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       16,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      12288, // should be ignored when block_ff_dim is set
|
||||
BlockFFDim: 12288,
|
||||
BlockAutoAdjustFFDim: true,
|
||||
BlockMultipleOf: 256,
|
||||
BlockFFNDimMultiplier: 1.0,
|
||||
NumAttentionHeads: 32,
|
||||
NumKeyValueHeads: 8,
|
||||
LayerTypes: []string{"conv", "full_attention"},
|
||||
NormEps: 1e-5,
|
||||
ConvLCache: 3,
|
||||
}
|
||||
|
||||
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||
|
||||
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
|
||||
t.Fatalf("feed_forward_length = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
convert/convert_lfm2_vl.go (new file, 417 lines)
@@ -0,0 +1,417 @@
package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
type lfm2VLTextModel struct {
	TextConfig            lfm2Model `json:"text_config"`
	DoImageSplitting      *bool     `json:"do_image_splitting"`
	DownsampleFactor      uint32    `json:"downsample_factor"`
	EncoderPatchSize      uint32    `json:"encoder_patch_size"`
	ImageTokenID          uint32    `json:"image_token_id"`
	MaxImageTokens        uint32    `json:"max_image_tokens"`
	MinImageTokens        uint32    `json:"min_image_tokens"`
	MaxTiles              uint32    `json:"max_tiles"`
	MinTiles              uint32    `json:"min_tiles"`
	TileSize              uint32    `json:"tile_size"`
	MaxPixelsTolerance    float32   `json:"max_pixels_tolerance"`
	ProjectorUseLayernorm bool      `json:"projector_use_layernorm"`
	ProjectorHiddenSize   uint32    `json:"projector_hidden_size"`
	ProjectorHiddenAct    string    `json:"projector_hidden_act"`
	UseImageSpecialTokens *bool     `json:"use_image_special_tokens"`
	UseThumbnail          *bool     `json:"use_thumbnail"`
	VisionConfig          struct {
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`
	} `json:"vision_config"`
	Processor struct {
		ImageProcessor struct {
			DoImageSplitting *bool     `json:"do_image_splitting"`
			DownsampleFactor uint32    `json:"downsample_factor"`
			MaxImageTokens   uint32    `json:"max_image_tokens"`
			MinImageTokens   uint32    `json:"min_image_tokens"`
			MaxTiles         uint32    `json:"max_tiles"`
			MinTiles         uint32    `json:"min_tiles"`
			MaxPixelsTol     float32   `json:"max_pixels_tolerance"`
			TileSize         uint32    `json:"tile_size"`
			UseThumbnail     *bool     `json:"use_thumbnail"`
			ImageMean        []float32 `json:"image_mean"`
			ImageStd         []float32 `json:"image_std"`
			Size             struct {
				Height uint32 `json:"height"`
				Width  uint32 `json:"width"`
			} `json:"size"`
		} `json:"image_processor"`
	}
}

func (p *lfm2VLTextModel) textModel() *lfm2Model {
	return &p.TextConfig
}

func (p *lfm2VLTextModel) specialTokenTypes() []string {
	return p.textModel().specialTokenTypes()
}

func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
	bts, err := fs.ReadFile(fsys, "processor_config.json")
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return nil
		}
		return err
	}

	return json.Unmarshal(bts, &p.Processor)
}

func (p *lfm2VLTextModel) visionImageSize() uint32 {
	// LFM2-VL image processor operates on 512 tiles and downsamples by factor 2
	// before projection. Keep a fixed square image size compatible with position
	// embeddings and the simplified runtime image pipeline.
	tile := cmp.Or(
		p.Processor.ImageProcessor.TileSize,
		p.Processor.ImageProcessor.Size.Height,
		p.Processor.ImageProcessor.Size.Width,
		uint32(512),
	)
	downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	if downsample == 0 {
		return tile
	}

	return max(uint32(1), tile/downsample)
}

func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
	kv := p.textModel().KV(t)

	boolOr := func(defaultValue bool, values ...*bool) bool {
		for _, v := range values {
			if v != nil {
				return *v
			}
		}
		return defaultValue
	}

	kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
	kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
	kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
	kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
	kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
	kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
	kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
	kv["vision.image_size"] = p.visionImageSize()
	kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
	kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
	kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
	kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
	kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
	kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
	kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
	kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
	kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
	kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
	kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))

	setVisionTokenID := func(k, token string) {
		if t == nil || t.Vocabulary == nil {
			return
		}
		for i, v := range t.Vocabulary.Tokens {
			if v == token {
				kv[k] = uint32(i)
				return
			}
		}
	}
	setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
	setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
	setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")

	return kv
}

func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
	patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
	numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))

	for _, t := range ts {
		if t.Name() == "v.patch_embd.weight" {
			shape := t.Shape()
			if len(shape) == 2 {
				inputDim := uint64(numChannels * patchSize * patchSize)
				if shape[1] == inputDim {
					channels := numChannels
					patch := patchSize
					t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
						return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
					})
				}
			}
		}
	}

	out := p.textModel().Tensors(ts)
	for _, t := range out {
		if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
			t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
		}
	}
	return out
}

func (p *lfm2VLTextModel) Replacements() []string {
	out := make([]string, 0, 96)

	addText := func(from, to string) {
		out = append(out, from, to)
		if strings.HasPrefix(from, "model.") {
			suffix := strings.TrimPrefix(from, "model.")
			out = append(out,
				"model.language_model."+suffix, to,
				"model.language_model.model."+suffix, to,
			)
		}
	}

	base := p.textModel().Replacements()
	for i := 0; i+1 < len(base); i += 2 {
		addText(base[i], base[i+1])
	}

	// Vision tower + multimodal projector tensors (single-file conversion).
	out = append(out,
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
		"model.multi_modal_projector.layer_norm", "mm.layer_norm",
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
	)

	return out
}

// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
type lfm2VLProjectorModel struct {
	ModelParameters
	DownsampleFactor   uint32 `json:"downsample_factor"`
	ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
	VisionModel        struct {
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		ImageSize         uint32  `json:"image_size"`
	} `json:"vision_config"`
	Processor struct {
		ImageProcessor struct {
			DownsampleFactor uint32    `json:"downsample_factor"`
			TileSize         uint32    `json:"tile_size"`
			ImageMean        []float32 `json:"image_mean"`
			ImageStd         []float32 `json:"image_std"`
			Size             struct {
				Height uint32 `json:"height"`
				Width  uint32 `json:"width"`
			} `json:"size"`
		} `json:"image_processor"`
	}
}

var (
	_ ModelConverter = (*lfm2VLTextModel)(nil)
	_ ModelConverter = (*lfm2VLProjectorModel)(nil)
	_ moreParser     = (*lfm2VLTextModel)(nil)
	_ moreParser     = (*lfm2VLProjectorModel)(nil)
)

func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
	bts, err := fs.ReadFile(fsys, "processor_config.json")
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return nil
		}
		return err
	}

	return json.Unmarshal(bts, &p.Processor)
}

func (p *lfm2VLProjectorModel) imageSize() uint32 {
	if p.VisionModel.ImageSize > 0 {
		return p.VisionModel.ImageSize
	}

	downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	baseSize := cmp.Or(
		p.Processor.ImageProcessor.TileSize,
		p.Processor.ImageProcessor.Size.Height,
		p.Processor.ImageProcessor.Size.Width,
		uint32(256),
	)
	if downsample == 0 {
		return baseSize
	}

	return max(uint32(1), baseSize/downsample)
}

func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
	kv := KV{
		"general.architecture":         "clip",
		"general.type":                 "mmproj",
		"general.file_type":            uint32(1),
		"general.quantization_version": uint32(2),
		"clip.has_vision_encoder":      true,
		"clip.projector_type":          "lfm2",
		"clip.use_gelu":                true,
	}

	kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
	kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
	kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
	kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
	kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
	kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
	kv["clip.vision.image_size"] = p.imageSize()
	kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
	kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))

	return kv
}

func defaultFloat32Slice(v, fallback []float32) []float32 {
	if len(v) > 0 {
		return v
	}

	return fallback
}

func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
	patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))

	for _, t := range ts {
		name := t.Name()
		if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
			continue
		}

		shape := t.Shape()
		if name == "v.patch_embd.weight" && len(shape) == 2 {
			inputDim := uint64(numChannels * patchSize * patchSize)
			if shape[1] == inputDim {
				shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
				channels := int(numChannels)
				patch := int(patchSize)
				t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
					return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
				})
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    slices.Clone(shape),
			WriterTo: t,
		})
	}

	return out
}

func (p *lfm2VLProjectorModel) Replacements() []string {
	return []string{
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
	}
}

func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
	if len(srcShape) != 2 {
		return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
	}

	outDim := int(srcShape[0])
	flatInputDim := int(srcShape[1])
	expectedInputDim := channels * patch * patch
	if flatInputDim != expectedInputDim {
		return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim)
	}

	expectedSize := outDim * flatInputDim
	if len(data) != expectedSize {
		return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize)
	}

	repacked := make([]float32, len(data))
	perChannel := patch * patch

	for o := range outDim {
		inBase := o * flatInputDim
		outBase := o * flatInputDim

		for y := range patch {
			for x := range patch {
				inPixelBase := inBase + (y*patch+x)*channels
				for c := range channels {
					src := inPixelBase + c
					dst := outBase + c*perChannel + y*patch + x
					repacked[dst] = data[src]
				}
			}
		}
	}

	return repacked, nil
}
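Reading note (not part of the diff): the repack above converts the HF patch-embedding layout, which flattens each patch pixel-major as (y, x, c), into the channel-major layout the runtime expects. For output row o with C channels and patch size P, element src = o*C*P*P + (y*P + x)*C + c moves to dst = o*C*P*P + c*P*P + y*P + x; the TestRepackPatchEmbeddingWeight case in the next file checks exactly this permutation for C = 2, P = 2.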
convert/convert_lfm2_vl_test.go (new file, 249 lines)
@@ -0,0 +1,249 @@
package convert

import (
	"slices"
	"strings"
	"testing"
)

func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
	p := lfm2VLTextModel{
		TextConfig: lfm2Model{
			ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
			HiddenSize:            2048,
			NumHiddenLayers:       16,
			MaxPositionEmbeddings: 128000,
			IntermediateSize:      12288,
			BlockFFDim:            12288,
			BlockAutoAdjustFFDim:  true,
			BlockMultipleOf:       256,
			BlockFFNDimMultiplier: 1.0,
			NumAttentionHeads:     32,
			NumKeyValueHeads:      8,
			LayerTypes:            []string{"conv", "full_attention"},
			NormEps:               1e-5,
			ConvLCache:            3,
		},
		DownsampleFactor: 2,
		VisionConfig: struct {
			HiddenSize        uint32  `json:"hidden_size"`
			IntermediateSize  uint32  `json:"intermediate_size"`
			NumAttentionHeads uint32  `json:"num_attention_heads"`
			NumHiddenLayers   uint32  `json:"num_hidden_layers"`
			NumChannels       uint32  `json:"num_channels"`
			PatchSize         uint32  `json:"patch_size"`
			LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		}{
			HiddenSize:        1152,
			IntermediateSize:  4304,
			NumAttentionHeads: 16,
			NumHiddenLayers:   27,
			NumChannels:       3,
			PatchSize:         16,
			LayerNormEpsilon:  1e-6,
		},
	}
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	kv := p.KV(&Tokenizer{
		Vocabulary: &Vocabulary{
			Model:  "gpt2",
			Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
		},
	})

	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}

	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}

	if got, want := kv["vision.block_count"], uint32(27); got != want {
		t.Fatalf("vision.block_count = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_size"], uint32(256); got != want {
		t.Fatalf("vision.image_size = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_token_id"], uint32(396); got != want {
		t.Fatalf("vision.image_token_id = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
		t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
	}

	if got, want := kv["vision.do_image_splitting"], true; got != want {
		t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
	}
	if got, want := kv["vision.min_tiles"], uint32(2); got != want {
		t.Fatalf("vision.min_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.max_tiles"], uint32(10); got != want {
		t.Fatalf("vision.max_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.tile_size"], uint32(512); got != want {
		t.Fatalf("vision.tile_size = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_thumbnail"], true; got != want {
		t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_image_special_tokens"], true; got != want {
		t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
	}
}

func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
	p := lfm2VLTextModel{}
	p.VisionConfig.PatchSize = 16
	p.VisionConfig.NumChannels = 3
	input := []Tensor{
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
		newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
	}

	out := p.Tensors(input)
	if len(out) == 0 {
		t.Fatal("expected non-empty tensor list")
	}

	foundPatch := false
	foundVision := false
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			foundPatch = true
			if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
				t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
			}
		}
		if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
			foundVision = true
		}
	}

	if !foundPatch {
		t.Fatal("expected v.patch_embd.weight in output tensors")
	}
	if !foundVision {
		t.Fatal("expected at least one vision/projector tensor in output")
	}
}

func TestLFM2VLTextModelReplacements(t *testing.T) {
	p := lfm2VLTextModel{}
	r := strings.NewReplacer(p.Replacements()...)

	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "language_model_embed_tokens",
			in:   "model.language_model.embed_tokens.weight",
			want: "token_embd.weight",
		},
		{
			name: "language_model_layers",
			in:   "model.language_model.layers.2.self_attn.q_proj.weight",
			want: "blk.2.attn_q.weight",
		},
		{
			name: "nested_language_model_prefix",
			in:   "model.language_model.model.embedding_norm.weight",
			want: "token_embd_norm.weight",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := r.Replace(tt.in); got != tt.want {
				t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}

func TestLFM2VLProjectorKV(t *testing.T) {
	p := lfm2VLProjectorModel{
		DownsampleFactor:   2,
		ProjectorHiddenDim: 2048,
	}
	p.VisionModel.NumHiddenLayers = 27
	p.VisionModel.HiddenSize = 1152
	p.VisionModel.IntermediateSize = 4304
	p.VisionModel.NumAttentionHeads = 16
	p.VisionModel.PatchSize = 16
	p.VisionModel.LayerNormEpsilon = 1e-6
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	kv := p.KV(nil)

	if got, want := kv["general.architecture"], "clip"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["clip.projector_type"], "lfm2"; got != want {
		t.Fatalf("clip.projector_type = %v, want %v", got, want)
	}
	if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
		t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
	}
}

func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
	p := lfm2VLProjectorModel{}
	p.VisionModel.NumChannels = 3
	p.VisionModel.PatchSize = 16

	input := []Tensor{
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
	}

	out := p.Tensors(input)
	if len(out) != 2 {
		t.Fatalf("expected 2 tensors, got %d", len(out))
	}

	var patchShape []uint64
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			patchShape = tns.Shape
			break
		}
	}

	if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
		t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
	}
}

func TestRepackPatchEmbeddingWeight(t *testing.T) {
	data := []float32{
		0, 1, // y=0,x=0
		2, 3, // y=0,x=1
		4, 5, // y=1,x=0
		6, 7, // y=1,x=1
	}

	got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
	if !slices.Equal(got, want) {
		t.Fatalf("repacked data = %v, want %v", got, want)
	}
}
convert/convert_nemotron_h.go (new file, 385 lines)
@@ -0,0 +1,385 @@
package convert

import (
	"cmp"
	"encoding/json"
	"fmt"
	"io/fs"
	"math"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type hybridPattern string

func (p *hybridPattern) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		*p = ""
		return nil
	}

	var single string
	if err := json.Unmarshal(data, &single); err == nil {
		*p = hybridPattern(strings.TrimSpace(single))
		return nil
	}

	var parts []string
	if err := json.Unmarshal(data, &parts); err == nil {
		*p = hybridPattern(strings.Join(parts, ""))
		return nil
	}

	return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}

type nemotronHModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32        `json:"max_position_embeddings"`
	HiddenSize            uint32        `json:"hidden_size"`
	NumHiddenLayers       uint32        `json:"num_hidden_layers"`
	NumAttentionHeads     uint32        `json:"num_attention_heads"`
	NumKeyValueHeads      uint32        `json:"num_key_value_heads"`
	HeadDim               uint32        `json:"head_dim"`
	LayerNormEpsilon      float32       `json:"layer_norm_epsilon"`
	NormEpsilon           float32       `json:"norm_eps"`
	RopeTheta             float32       `json:"rope_theta"`
	PartialRotaryFactor   float32       `json:"partial_rotary_factor"`
	ConvKernel            uint32        `json:"conv_kernel"`
	SSMStateSize          uint32        `json:"ssm_state_size"`
	MambaNumHeads         uint32        `json:"mamba_num_heads"`
	MambaHeadDim          uint32        `json:"mamba_head_dim"`
	NGroups               uint32        `json:"n_groups"`
	IntermediateSize      uint32        `json:"intermediate_size"`
	HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`

	// MoE
	NumExperts                  uint32  `json:"num_experts"`
	NumSharedExperts            uint32  `json:"num_shared_experts"`
	NRoutedExperts              uint32  `json:"n_routed_experts"`
	NSharedExperts              uint32  `json:"n_shared_experts"`
	NumExpertsPerTok            uint32  `json:"num_experts_per_tok"`
	MoEIntermediateSize         uint32  `json:"moe_intermediate_size"`
	MoESharedExpertIntermediate uint32  `json:"moe_shared_expert_intermediate_size"`
	NormTopKProb                bool    `json:"norm_topk_prob"`
	RoutedScalingFactor         float32 `json:"routed_scaling_factor"`
	ExpertGroupCount            uint32  `json:"n_group"`
	ExpertGroupUsedCount        uint32  `json:"topk_group"`
}

var _ ModelConverter = (*nemotronHModel)(nil)

func (n *nemotronHModel) parseMore(_ fs.FS) error {
	if n.NumHiddenLayers == 0 {
		return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
	}
	if n.HiddenSize == 0 {
		return fmt.Errorf("nemotron_h: hidden_size must be set")
	}
	if n.NumAttentionHeads == 0 {
		return fmt.Errorf("nemotron_h: num_attention_heads must be set")
	}
	if n.HeadDim == 0 {
		if n.HiddenSize%n.NumAttentionHeads != 0 {
			return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
		}
		n.HeadDim = n.HiddenSize / n.NumAttentionHeads
	}
	if n.NumKeyValueHeads == 0 {
		n.NumKeyValueHeads = n.NumAttentionHeads
	}
	if n.ConvKernel == 0 {
		return fmt.Errorf("nemotron_h: conv_kernel must be set")
	}
	if n.SSMStateSize == 0 {
		return fmt.Errorf("nemotron_h: ssm_state_size must be set")
	}
	if n.ssmHeadCount() == 0 {
		return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
	}
	if n.MambaHeadDim == 0 {
		return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
	}
	if n.NGroups == 0 {
		n.NGroups = 1
	}

	if _, _, err := n.layerArrays(); err != nil {
		return err
	}

	if n.isMoE() {
		if n.routedExpertCount() == 0 {
			return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
		}
		if n.NumExpertsPerTok == 0 {
			return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
		}
		if n.NumExpertsPerTok > n.routedExpertCount() {
			return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
		}
		if n.moeIntermediateSize() == 0 {
			return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
		}
	}

	return nil
}

func (n *nemotronHModel) isMoE() bool {
	return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
}

func (n *nemotronHModel) routedExpertCount() uint32 {
	return cmp.Or(n.NRoutedExperts, n.NumExperts)
}

func (n *nemotronHModel) sharedExpertCount() uint32 {
	return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
}

func (n *nemotronHModel) ssmHeadCount() uint32 {
	return n.MambaNumHeads
}

func (n *nemotronHModel) ssmInnerSize() uint32 {
	return n.MambaHeadDim * n.ssmHeadCount()
}

func (n *nemotronHModel) epsilon() float32 {
	return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
}

func (n *nemotronHModel) moeIntermediateSize() uint32 {
	return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
}

func (n *nemotronHModel) denseIntermediateSize() uint32 {
	return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
}

func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
	pattern := strings.TrimSpace(string(n.HybridOverridePattern))
	if pattern == "" {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
	}

	runes := []rune(pattern)
	if len(runes) != int(n.NumHiddenLayers) {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
	}

	headCountKV = make([]uint32, n.NumHiddenLayers)
	ffnLengths = make([]uint32, n.NumHiddenLayers)

	attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
	moeFFN := n.moeIntermediateSize()
	denseFFN := n.denseIntermediateSize()

	for i, layerType := range runes {
		switch layerType {
		case 'M':
			// Recurrent layer: no KV heads and no FFN.
		case '*', 'A':
			// Attention-only layer.
			headCountKV[i] = attnKVHeads
		case 'E':
			// MoE layer.
			if moeFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
			}
			ffnLengths[i] = moeFFN
		case '-':
			// Dense FFN layer.
			if denseFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
			}
			ffnLengths[i] = denseFFN
		default:
			return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
		}
	}

	return headCountKV, ffnLengths, nil
}

func (n *nemotronHModel) KV(t *Tokenizer) KV {
	kv := n.ModelParameters.KV(t)

	arch := "nemotron_h"
	if n.isMoE() {
		arch = "nemotron_h_moe"
	}
	kv["general.architecture"] = arch
	kv["block_count"] = n.NumHiddenLayers
	kv["context_length"] = n.MaxPositionEmbeddings
	kv["embedding_length"] = n.HiddenSize
	kv["attention.head_count"] = n.NumAttentionHeads
	kv["attention.key_length"] = n.HeadDim
	kv["attention.value_length"] = n.HeadDim
	kv["attention.layer_norm_epsilon"] = n.epsilon()
	kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
	kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
	if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
		kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
	}

	if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
		kv["attention.head_count_kv"] = headCountKV
		kv["feed_forward_length"] = ffnLengths
	}

	kv["ssm.conv_kernel"] = n.ConvKernel
	kv["ssm.inner_size"] = n.ssmInnerSize()
	kv["ssm.state_size"] = n.SSMStateSize
	kv["ssm.group_count"] = n.NGroups
	kv["ssm.time_step_rank"] = n.ssmHeadCount()

	if n.isMoE() {
		kv["expert_count"] = n.routedExpertCount()
		kv["expert_used_count"] = n.NumExpertsPerTok
		kv["expert_feed_forward_length"] = n.moeIntermediateSize()
		if n.sharedExpertCount() > 0 {
			kv["expert_shared_count"] = n.sharedExpertCount()
		}
		if n.MoESharedExpertIntermediate > 0 {
			kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
		}
		kv["expert_weights_norm"] = n.NormTopKProb
		kv["expert_weights_scale"] = n.RoutedScalingFactor
		if n.ExpertGroupCount > 0 {
			kv["expert_group_count"] = n.ExpertGroupCount
		}
		if n.ExpertGroupUsedCount > 0 {
			kv["expert_group_used_count"] = n.ExpertGroupUsedCount
		}
	}

	return kv
}

func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
	switch len(shape) {
	case 1:
		return []uint64{shape[0], 1}
	case 2:
		if shape[0] == 1 && shape[1] > 1 {
			return []uint64{shape[1], 1}
		}
		if shape[1] == 1 && shape[0] > 1 {
			return []uint64{shape[0], 1}
		}
	}

	return slices.Clone(shape)
}

func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	remaining := ts
	if n.isMoE() {
		merges := make([]merge, 0, n.NumHiddenLayers*2)
		for i := range n.NumHiddenLayers {
			merges = append(merges, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
			}, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
			})
		}

		merged, rest := mergeTensors(ts, merges...)
		out = append(out, merged...)
		remaining = rest
	}

	nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
	for _, t := range remaining {
		name := t.Name()
		shape := slices.Clone(t.Shape())

		switch {
		case strings.HasSuffix(name, ".ssm_a"):
			shape = normalizeVectorShapeToColumn(shape)
			t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
				out := make([]float32, len(data))
				for i, v := range data {
					out[i] = -float32(math.Exp(float64(v)))
				}
				return out, nil
			})
		case strings.HasSuffix(name, ".ssm_d"):
			shape = normalizeVectorShapeToColumn(shape)
		case strings.HasSuffix(name, ".ssm_norm.weight"):
			switch len(shape) {
			case 1:
				if nGroups > 0 && shape[0]%nGroups == 0 {
					shape = []uint64{nGroups, shape[0] / nGroups}
				}
			case 2:
				if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
					shape = []uint64{nGroups, shape[1] / nGroups}
				}
			}
		case strings.HasSuffix(name, ".ssm_conv1d.weight"):
			if len(shape) == 3 {
				if shape[0] == 1 {
					shape = []uint64{shape[1], shape[2]}
				} else if shape[1] == 1 {
					shape = []uint64{shape[0], shape[2]}
				}
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    shape,
			WriterTo: t,
		})
	}

	return out
}

func (n *nemotronHModel) Replacements() []string {
	return []string{
		// Embedding and output
		"lm_head", "output",
		"backbone.embeddings", "token_embd",
		"backbone.norm_f", "output_norm",
		"backbone.layers", "blk",

		// Recurrent (Mamba2) tensors
		"mixer.in_proj", "ssm_in",
		"mixer.out_proj", "ssm_out",
		"mixer.dt_bias", "ssm_dt.bias",
		"mixer.A_log", "ssm_a",
		"mixer.D", "ssm_d",
		"mixer.conv1d", "ssm_conv1d",
		"mixer.norm.weight", "ssm_norm.weight",

		// Attention tensors
		"mixer.q_proj", "attn_q",
		"mixer.k_proj", "attn_k",
		"mixer.v_proj", "attn_v",
		"mixer.o_proj", "attn_output",

		// FFN / MoE tensors
		"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mixer.gate", "ffn_gate_inp",
		"mixer.fc1_latent_proj", "ffn_latent_in",
		"mixer.fc2_latent_proj", "ffn_latent_out",
		"mixer.shared_experts.up_proj", "ffn_up_shexp",
		"mixer.shared_experts.down_proj", "ffn_down_shexp",
		"mixer.up_proj", "ffn_up",
		"mixer.down_proj", "ffn_down",

		// Per-layer pre-norm
		".norm.weight", ".attn_norm.weight",
	}
}
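Editor's note: to make the hybrid_override_pattern semantics concrete, here is a distilled, standalone sketch of the per-rune mapping that layerArrays performs (illustrative only, with made-up sizes; the real method also validates pattern length against num_hidden_layers and returns errors for unknown runes):

package main

import "fmt"

// layerKinds mirrors the per-rune rules in nemotronHModel.layerArrays:
// 'M' is a recurrent layer (no KV heads, no FFN), '*'/'A' is attention
// (KV heads only), 'E' is a MoE FFN layer, '-' is a dense FFN layer.
func layerKinds(pattern string, kvHeads, moeFFN, denseFFN uint32) (heads, ffn []uint32) {
	for _, r := range pattern {
		switch r {
		case 'M':
			heads, ffn = append(heads, 0), append(ffn, 0)
		case '*', 'A':
			heads, ffn = append(heads, kvHeads), append(ffn, 0)
		case 'E':
			heads, ffn = append(heads, 0), append(ffn, moeFFN)
		case '-':
			heads, ffn = append(heads, 0), append(ffn, denseFFN)
		}
	}
	return heads, ffn
}

func main() {
	heads, ffn := layerKinds("MEM*E", 8, 1856, 11008)
	fmt.Println(heads) // [0 0 0 8 0]
	fmt.Println(ffn)   // [0 1856 0 0 1856]
}

The test file below exercises the same "MEM*E" pattern through the real converter.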
convert/convert_nemotron_h_test.go (new file, 230 lines)
@@ -0,0 +1,230 @@
package convert

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"testing"
)

func TestHybridPatternUnmarshal(t *testing.T) {
	t.Run("string", func(t *testing.T) {
		var p hybridPattern
		if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
			t.Fatal(err)
		}
		if got, want := string(p), "MEM*"; got != want {
			t.Fatalf("unexpected pattern: got %q want %q", got, want)
		}
	})

	t.Run("array", func(t *testing.T) {
		var p hybridPattern
		if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
			t.Fatal(err)
		}
		if got, want := string(p), "MEM*"; got != want {
			t.Fatalf("unexpected pattern: got %q want %q", got, want)
		}
	})
}

func TestNemotronHLayerArrays(t *testing.T) {
	m := &nemotronHModel{
		NumHiddenLayers:       5,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		HybridOverridePattern: "MEM*E",
		NRoutedExperts:        128,
		NumExpertsPerTok:      6,
		MoEIntermediateSize:   1856,
	}

	headsKV, ffn, err := m.layerArrays()
	if err != nil {
		t.Fatal(err)
	}

	if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
		t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
	}
	if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
		t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
	}
}

func TestNemotronHKV(t *testing.T) {
	m := &nemotronHModel{
		MaxPositionEmbeddings:       1048576,
		HiddenSize:                  2688,
		NumHiddenLayers:             5,
		NumAttentionHeads:           32,
		NumKeyValueHeads:            2,
		HeadDim:                     128,
		LayerNormEpsilon:            1e-5,
		RopeTheta:                   10000,
		PartialRotaryFactor:         0.5,
		ConvKernel:                  4,
		SSMStateSize:                128,
		MambaNumHeads:               64,
		MambaHeadDim:                64,
		NGroups:                     8,
		HybridOverridePattern:       "MEM*E",
		NRoutedExperts:              128,
		NSharedExperts:              1,
		NumExpertsPerTok:            6,
		MoEIntermediateSize:         1856,
		MoESharedExpertIntermediate: 3712,
		NormTopKProb:                true,
		RoutedScalingFactor:         2.5,
	}
	if err := m.parseMore(nil); err != nil {
		t.Fatal(err)
	}

	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
		t.Fatalf("unexpected architecture: got %v want %v", got, want)
	}

	headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
	}
	if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
		t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
	}

	ffnLength, ok := kv["feed_forward_length"].([]uint32)
	if !ok {
		t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
	}
	if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
		t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
	}
}

func TestNemotronHTensorsTransforms(t *testing.T) {
	m := &nemotronHModel{NGroups: 8}
	in := []Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_a",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_d",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_norm.weight",
			shape: []uint64{16},
			data:  make([]float32, 16),
		},
		&fakeTensor{
			name:  "blk.0.ssm_conv1d.weight",
			shape: []uint64{10, 1, 4},
			data:  make([]float32, 40),
		},
	}

	out := m.Tensors(in)
	if len(out) != len(in) {
		t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
	}

	got := map[string]struct {
		shape  []uint64
		writer io.WriterTo
	}{}
	for _, t := range out {
		got[t.Name] = struct {
			shape  []uint64
			writer io.WriterTo
		}{shape: t.Shape, writer: t.WriterTo}
	}

	if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_a shape: %v", shape)
	}
	if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_d shape: %v", shape)
	}
	if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
		t.Fatalf("unexpected ssm_norm shape: %v", shape)
	}
	if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
		t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
	}

	var b bytes.Buffer
	if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
		t.Fatal(err)
	}
	values := make([]float32, 4)
	if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
		t.Fatal(err)
	}
	// 0 -> -exp(0) == -1
	if values[0] != -1 {
		t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
	}
}

func TestNemotronHLoadModelMetadata(t *testing.T) {
	tempDir := t.TempDir()

	config := `{
		"architectures": ["NemotronHForCausalLM"],
		"model_type": "nemotron_h",
		"num_hidden_layers": 4,
		"hidden_size": 512,
		"max_position_embeddings": 32768,
		"num_attention_heads": 8,
		"num_key_value_heads": 2,
		"head_dim": 64,
		"layer_norm_epsilon": 1e-5,
		"conv_kernel": 4,
		"ssm_state_size": 128,
		"mamba_num_heads": 16,
		"mamba_head_dim": 32,
		"n_groups": 8,
		"hybrid_override_pattern": "ME*M",
		"n_routed_experts": 16,
		"num_experts_per_tok": 4,
		"moe_intermediate_size": 256
	}`

	if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
		t.Fatal(err)
	}

	kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
	if err != nil {
		t.Fatal(err)
	}
	if _, ok := kv.(*nemotronHModel); !ok {
		t.Fatalf("unexpected converter type: %T", kv)
	}
}

func TestNemotronHReplacementsLatentProjections(t *testing.T) {
	m := &nemotronHModel{}
	r := strings.NewReplacer(m.Replacements()...)

	if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
		t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
	}
	if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
		t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
	}
}
convert/json_compat.go (new file, 97 lines)
@@ -0,0 +1,97 @@
package convert

// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
//
// This is intentionally conservative:
//   - only runs outside quoted strings
//   - only rewrites full tokens
//
// We map these values to 0 because encoding/json rejects non-finite values,
// and these fields are typically model-side metadata not consumed by the
// converter.
func sanitizeNonFiniteJSON(in []byte) []byte {
	if len(in) == 0 {
		return in
	}

	out := make([]byte, 0, len(in))
	inString := false
	escape := false

	for i := 0; i < len(in); {
		c := in[i]

		if inString {
			out = append(out, c)
			if escape {
				escape = false
			} else if c == '\\' {
				escape = true
			} else if c == '"' {
				inString = false
			}
			i++
			continue
		}

		if c == '"' {
			inString = true
			out = append(out, c)
			i++
			continue
		}

		if hasToken(in, i, "-Infinity") {
			out = append(out, '0')
			i += len("-Infinity")
			continue
		}

		if hasToken(in, i, "Infinity") {
			out = append(out, '0')
			i += len("Infinity")
			continue
		}

		if hasToken(in, i, "NaN") {
			out = append(out, '0')
			i += len("NaN")
			continue
		}

		out = append(out, c)
		i++
	}

	return out
}

func hasToken(in []byte, at int, tok string) bool {
	end := at + len(tok)
	if at < 0 || end > len(in) {
		return false
	}
	if string(in[at:end]) != tok {
		return false
	}
	if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
		return false
	}
	if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
		return false
	}
	return true
}

func isJSONWhitespace(b byte) bool {
	return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}

func isJSONValuePrefixBoundary(b byte) bool {
	return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
}

func isJSONValueSuffixBoundary(b byte) bool {
	return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
}
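Editor's note: the failure mode that motivates this pass is easy to reproduce. Go's encoding/json fails fast on a bare Infinity/NaN token, so a config containing one cannot be parsed at all without sanitizing first. A minimal standalone sketch (the field name is made up; the second document simulates the rewrite that sanitizeNonFiniteJSON performs):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var cfg map[string]float64

	// encoding/json rejects the bare Infinity token outright, with an
	// error along the lines of: invalid character 'I' looking for
	// beginning of value.
	err := json.Unmarshal([]byte(`{"rope_scaling_max": Infinity}`), &cfg)
	fmt.Println(err)

	// After the token is rewritten to 0, the same document parses cleanly.
	if err := json.Unmarshal([]byte(`{"rope_scaling_max": 0}`), &cfg); err == nil {
		fmt.Println(cfg["rope_scaling_max"]) // 0
	}
}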
convert/json_compat_test.go (new file, 46 lines)
@@ -0,0 +1,46 @@
package convert

import "testing"

func TestSanitizeNonFiniteJSON(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "infinity token",
			in:   `{"a":[0,Infinity,1]}`,
			want: `{"a":[0,0,1]}`,
		},
		{
			name: "negative infinity token",
			in:   `{"a":-Infinity}`,
			want: `{"a":0}`,
		},
		{
			name: "nan token",
			in:   `{"a":NaN}`,
			want: `{"a":0}`,
		},
		{
			name: "tokens inside strings untouched",
			in:   `{"a":"Infinity -Infinity NaN","b":Infinity}`,
			want: `{"a":"Infinity -Infinity NaN","b":0}`,
		},
		{
			name: "identifier-like token untouched",
			in:   `{"a":InfinityValue}`,
			want: `{"a":InfinityValue}`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
			if got != tt.want {
				t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
			}
		})
	}
}
@@ -212,8 +212,13 @@ type tokenizer struct {
	PreTokenizer struct {
		PreTokenizers []struct {
			Type           string `json:"type"`
			Behavior       string `json:"behavior"`
			Invert         bool   `json:"invert"`
			AddPrefixSpace bool   `json:"add_prefix_space"`
			TrimOffsets    bool   `json:"trim_offsets"`
			UseRegex       bool   `json:"use_regex"`
			Pattern        struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
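Editor's note: the schema change above means each entry of a HF tokenizer's pre_tokenizer Sequence is now decoded with its Split and ByteLevel options. A minimal, standalone decode of the new fields, using a hypothetical local struct that mirrors the diff (not the package's actual API surface):

package main

import (
	"encoding/json"
	"fmt"
)

// preTokenizer mirrors the fields added to the tokenizer struct above.
type preTokenizer struct {
	PreTokenizers []struct {
		Type           string `json:"type"`
		Behavior       string `json:"behavior"`
		Invert         bool   `json:"invert"`
		AddPrefixSpace bool   `json:"add_prefix_space"`
		TrimOffsets    bool   `json:"trim_offsets"`
		UseRegex       bool   `json:"use_regex"`
		Pattern        struct {
			Regex string `json:"Regex"`
		} `json:"pattern"`
	} `json:"pretokenizers"`
}

func main() {
	raw := []byte(`{"pretokenizers":[
		{"type":"Split","pattern":{"Regex":"\\p{N}{1,3}"},"behavior":"Isolated","invert":false},
		{"type":"ByteLevel","add_prefix_space":false,"trim_offsets":true,"use_regex":false}
	]}`)

	var p preTokenizer
	if err := json.Unmarshal(raw, &p); err != nil {
		panic(err)
	}
	for _, pt := range p.PreTokenizers {
		fmt.Println(pt.Type, pt.Behavior, pt.Pattern.Regex)
	}
}

The test hunk that follows feeds a full llama-bpe pre_tokenizer through the real parser.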
@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
				Pre: "default",
			},
		},
		{
			name: "llama-bpe pretokenizer and control tokens",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{"id": 1, "content": "<|startoftext|>", "special": true},
						{"id": 6, "content": "<|im_start|>", "special": true},
						{"id": 7, "content": "<|im_end|>", "special": true},
						{"id": 8, "content": "<|tool_list_start|>", "special": true},
						{"id": 9, "content": "<|tool_list_end|>", "special": true},
						{"id": 10, "content": "<|tool_call_start|>", "special": true},
						{"id": 11, "content": "<|tool_call_end|>", "special": true},
						{"id": 12, "content": "<|tool_response_start|>", "special": true},
						{"id": 13, "content": "<|tool_response_end|>", "special": true},
						{"id": 396, "content": "<image>", "special": true},
						{"id": 64400, "content": "<think>", "special": true},
						{"id": 64401, "content": "</think>", "special": true}
					],
					"model": {
						"vocab": {
							"<|startoftext|>": 1,
							"<|im_start|>": 6,
							"<|im_end|>": 7,
							"<|tool_list_start|>": 8,
							"<|tool_list_end|>": 9,
							"<|tool_call_start|>": 10,
							"<|tool_call_end|>": 11,
							"<|tool_response_start|>": 12,
							"<|tool_response_end|>": 13,
							"<image>": 396,
							"<think>": 64400,
							"</think>": 64401
						}
					},
					"pre_tokenizer": {
						"type": "Sequence",
						"pretokenizers": [
							{
								"type": "Split",
								"pattern": {
									"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
								},
								"behavior": "Isolated",
								"invert": false
							},
							{
								"type": "ByteLevel",
								"add_prefix_space": false,
								"trim_offsets": true,
								"use_regex": false
							}
						]
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model: "gpt2",
					Tokens: []string{
						"<|startoftext|>",
						"<|im_start|>",
						"<|im_end|>",
						"<|tool_list_start|>",
						"<|tool_list_end|>",
						"<|tool_call_start|>",
						"<|tool_call_end|>",
						"<|tool_response_start|>",
						"<|tool_response_end|>",
						"<image>",
						"<think>",
						"</think>",
					},
					Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
					Types:  []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
				},
				Pre: "llama-bpe",
			},
		},
		{
			name: "list string merges",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
	return uint64(kv.Uint("ssm.group_count"))
}

func (kv KV) FFNLength() []uint64 {
	ffnLengthDefault := uint32(0)
	ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
	if len(ffnLength) == 1 {
		ffnLengthDefault = ffnLength[0]
	}
	nLayers := int(kv.BlockCount())
	if len(ffnLength) > nLayers {
		slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
	}
	out := make([]uint64, nLayers)
	for i := range nLayers {
		if i >= len(ffnLength) {
			out[i] = uint64(ffnLengthDefault)
		} else {
			out[i] = uint64(ffnLength[i])
		}
	}
	return out
}

// general types

func (kv KV) String(key string, defaultValue ...string) string {
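// A worked illustration of the pad/broadcast behavior above, using
// hypothetical metadata values and assuming block_count = 4:
//
//	feed_forward_length = 2048          -> FFNLength() == [2048 2048 2048 2048]
//	                                       (a single value becomes the default
//	                                        and is broadcast to every layer)
//	feed_forward_length = [1024, 2048]  -> FFNLength() == [1024 2048 0 0]
//	                                       (missing layers keep the zero
//	                                        default; an over-long array only
//	                                        logs a warning)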
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
		"llama4",
		"mistral3",
		"mllama",
		"nemotron_h", "nemotron_h_moe",
		"nomic-bert",
		"olmo3",
		"qwen25vl",
@@ -273,6 +295,7 @@ func (kv KV) OllamaEngineRequired() bool {
		"glm4moelite",
		"glmocr",
		"lfm2",
		"lfm2moe",
	}, kv.Architecture())
}
@@ -864,7 +887,9 @@ func (f GGML) FlashAttention() bool {
		"glmocr",
		"gptoss", "gpt-oss",
		"lfm2",
		"lfm2moe",
		"mistral3",
		"nemotron_h", "nemotron_h_moe",
		"olmo3",
		"qwen3", "qwen3moe",
		"qwen3next",
kvcache/recurrent.go (new file, 752 lines)
@@ -0,0 +1,752 @@
package kvcache

import (
	"errors"
	"fmt"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

const (
	DefaultCheckpointCount    = 32
	DefaultCheckpointMinPos   = int32(16)
	DefaultCheckpointInterval = int32(1280)
)

var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")

// RecurrentConfig configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
	Shift               func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
	ConvDim             int
	ConvChannels        int
	RecurrentStateSize  int
	CheckpointLogPrefix string
}

var (
	_ Cache           = (*Recurrent)(nil)
	_ CheckpointCache = (*Recurrent)(nil)
)
// Recurrent stores:
// - a standard causal KV cache
// - per-sequence conv state for recurrent operators
// - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
	kv *Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	// Conv state dimensions
	convDim      int
	convChannels int

	// Recurrent state dimensions
	recurrentStateSize int

	logPrefix string

	// slot mapping for recurrent state (copy-on-write)
	slotForSeq  map[int]int
	refCount    []int
	freeSlots   []int
	seqCounts   map[int]int
	slotScratch [1]int32

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]

	// per-layer recurrent state buffers (allocated lazily)
	recurrentCtxs   map[int]ml.Context
	recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]

	// recurrent checkpoints (per slot)
	checkpointCount     int
	checkpointMinPos    int32
	checkpointInterval  int32
	checkpointCtxSize   int
	checkpoints         map[int]*slotCheckpointStore
	pendingRestore      map[int]checkpointRestore
	curCheckpointPos    []int32
	curCheckpointSlots  map[int]int
	reserveCheckpoints  bool
	checkpointConvCtxs  map[int]ml.Context
	checkpointRecurCtxs map[int]ml.Context
	checkpointReserved  map[int]struct{}

	// current forward batch (derived in StartForward)
	curSeqs       []int
	curSlots      []int
	curSlotsInput ml.Tensor
	curSeqTokens  int

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	writableError   error
}

func NewRecurrentCache(config RecurrentConfig) *Recurrent {
	return &Recurrent{
		kv:                  NewCausalCache(config.Shift),
		convDim:             config.ConvDim,
		convChannels:        config.ConvChannels,
		recurrentStateSize:  config.RecurrentStateSize,
		logPrefix:           config.CheckpointLogPrefix,
		slotForSeq:          make(map[int]int),
		seqCounts:           make(map[int]int),
		convCtxs:            make(map[int]ml.Context),
		convStates:          make(map[int]ml.Tensor),
		recurrentCtxs:       make(map[int]ml.Context),
		recurrentStates:     make(map[int]ml.Tensor),
		checkpointCount:     DefaultCheckpointCount,
		checkpointMinPos:    DefaultCheckpointMinPos,
		checkpointInterval:  DefaultCheckpointInterval,
		checkpoints:         make(map[int]*slotCheckpointStore),
		pendingRestore:      make(map[int]checkpointRestore),
		curCheckpointSlots:  make(map[int]int),
		checkpointConvCtxs:  make(map[int]ml.Context),
		checkpointRecurCtxs: make(map[int]ml.Context),
		checkpointReserved:  make(map[int]struct{}),
	}
}
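// A minimal construction sketch, not part of this change: the values mirror
// how the LFM2 hybrid cache later in this diff wires up the shared recurrent
// cache, while shiftFn, stateSize, and the sizing parameters are placeholders.
//
//	cache := NewRecurrentCache(RecurrentConfig{
//		Shift:               shiftFn,    // RoPE shift for the causal KV half
//		ConvDim:             dConv,      // conv window per channel
//		ConvChannels:        hiddenSize, // channels per recurrent layer
//		RecurrentStateSize:  stateSize,  // flattened per-sequence state
//		CheckpointLogPrefix: "mymodel",
//	})
//	cache.Init(backend, ml.DTypeF16, maxSequences, contextLen, maxBatch)
//	defer cache.Close()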
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences
	c.checkpoints = make(map[int]*slotCheckpointStore)
	c.pendingRestore = make(map[int]checkpointRestore)
	c.curCheckpointPos = c.curCheckpointPos[:0]
	c.curCheckpointSlots = make(map[int]int)
	c.checkpointReserved = make(map[int]struct{})
	c.checkpointCtxSize = c.checkpointCount * c.maxSequences
	if c.checkpointCtxSize < 8 {
		c.checkpointCtxSize = 8
	}

	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}

	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *Recurrent) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	for _, ctx := range c.recurrentCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointConvCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointRecurCtxs {
		ctx.Close()
	}
	c.kv.Close()
}

func (c *Recurrent) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

func (c *Recurrent) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}

	nTokens := len(batch.Sequences)
	if nTokens == 0 {
		c.curSeqs = c.curSeqs[:0]
		c.curSlots = c.curSlots[:0]
		c.curSlotsInput = nil
		c.curSeqTokens = 0
		c.reserveCheckpoints = false
		c.writableEnsured = false
		c.writableError = nil
		return nil
	}

	// Fast path for single-sequence batches (common during decode and prefill).
	firstSeq := batch.Sequences[0]
	singleSeq := true
	for _, s := range batch.Sequences[1:] {
		if s != firstSeq {
			singleSeq = false
			break
		}
	}
	if singleSeq {
		return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
	}

	// Derive equal-length sequence layout for recurrent layers.
	seqCounts := c.seqCounts
	for s := range seqCounts {
		delete(seqCounts, s)
	}

	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if seqCounts[s] == 0 {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}

	nSeqs := len(c.curSeqs)
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return ErrNotSupported
		}
	}

	c.curSeqTokens = want

	if reserve {
		c.curSlots = c.curSlots[:0]
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
		}
		c.finalizeStartForward(ctx, batch, true)
		return nil
	}

	// Ensure slots exist for sequences in this batch.
	c.curSlots = c.curSlots[:0]
	var newSlots []int
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}

	if len(newSlots) > 0 {
		c.zeroSlots(ctx, newSlots)
	}

	c.finalizeStartForward(ctx, batch, false)

	return nil
}
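// In other words, a multi-sequence batch must form an equal-length grid of
// tokens. A hypothetical illustration:
//
//	batch.Sequences = [0, 0, 1, 1] -> curSeqs = [0, 1], curSeqTokens = 2
//	batch.Sequences = [0, 0, 1]    -> counts differ (2 vs 1): ErrNotSupported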
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
	c.curSeqs = append(c.curSeqs[:0], seq)
	c.curSeqTokens = seqTokens

	if reserve {
		c.curSlots = append(c.curSlots[:0], 0)
		c.finalizeStartForward(ctx, batch, true)
		return nil
	}

	slot, ok := c.slotForSeq[seq]
	if !ok {
		var err error
		slot, err = c.allocSlot()
		if err != nil {
			return err
		}

		c.slotForSeq[seq] = slot
		c.refCount[slot] = 1
		slotList := [1]int{slot}
		c.zeroSlots(ctx, slotList[:])
	}

	c.curSlots = append(c.curSlots[:0], slot)
	c.finalizeStartForward(ctx, batch, false)

	return nil
}

func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
	c.setCurSlotsInput(ctx)
	c.writableEnsured = false
	c.writableError = nil
	c.reserveCheckpoints = reserve
	c.planCheckpoints(batch)
}

func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
	c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
}

func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
	switch len(slots) {
	case 0:
		return nil
	case 1:
		c.slotScratch[0] = int32(slots[0])
		return ctx.Input().FromInts(c.slotScratch[:], 1)
	default:
		slotIndices := make([]int32, len(slots))
		for i, v := range slots {
			slotIndices[i] = int32(v)
		}
		return ctx.Input().FromInts(slotIndices, len(slotIndices))
	}
}
func (c *Recurrent) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, ErrKvCacheFull
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

func (c *Recurrent) freeSlot(slot int) {
	if slot >= 0 && slot < c.maxSequences {
		c.freeSlots = append(c.freeSlots, slot)
	}
}

// zeroSlots zeros recurrent state for the given slots across all cached layers.
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 {
		return
	}

	inputCtx := ctx.Input()
	slotsTensor := c.slotsInput(ctx, slots)

	if len(c.convStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
		for _, buf := range c.convStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}

	if len(c.recurrentStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
		for _, buf := range c.recurrentStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}

		if slot < 0 || slot >= len(c.refCount) {
			continue
		}

		if c.refCount[slot] <= 1 {
			continue
		}

		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot

		c.copyRecurrentState(ctx, slot, newSlot)
		c.copyCheckpoints(ctx, slot, newSlot)
	}

	c.setCurSlotsInput(ctx)

	return nil
}

func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
	src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
	dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)

	for _, buf := range c.convStates {
		rows := buf.Rows(ctx, src)
		if rows.DType() != ml.DTypeF32 {
			rows = rows.Cast(ctx, ml.DTypeF32)
		}
		ctx.Forward(buf.SetRows(ctx, rows, dst))
	}

	for _, buf := range c.recurrentStates {
		rows := buf.Rows(ctx, src)
		if rows.DType() != ml.DTypeF32 {
			rows = rows.Cast(ctx, ml.DTypeF32)
		}
		ctx.Forward(buf.SetRows(ctx, rows, dst))
	}
}
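// A sketch of the copy-on-write lifecycle, assuming sequence 1 was forked
// from sequence 0 via CopyPrefix so both initially share slot 0:
//
//	slotForSeq = {0: 0, 1: 0}, refCount[0] == 2   // after CopyPrefix
//
// The next forward pass that touches seq 1's recurrent state triggers
// EnsureWritable, which allocates a fresh slot, copies the conv/recurrent
// rows and checkpoints into it, and rebalances the reference counts:
//
//	slotForSeq = {0: 0, 1: 1}, refCount == [1, 1] // after EnsureWritable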
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		if c.validSlot(dstSlot) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}

	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		return
	}

	if c.validSlot(srcSlot) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}

func (c *Recurrent) CanResume(seq int, pos int32) bool {
	if !c.kv.CanResume(seq, pos) {
		return false
	}
	if pos == 0 {
		return true
	}
	return c.hasCheckpoint(seq, pos)
}
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
	if beginIndex > 0 && endIndex != math.MaxInt32 {
		if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
			return err
		}
		delete(c.pendingRestore, seq)

		slot, ok := c.slotForSeq[seq]
		if !ok || !c.validSlot(slot) {
			return nil
		}

		// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
		if c.refCount[slot] > 1 {
			newSlot, err := c.allocSlot()
			if err != nil {
				return err
			}
			ctx := c.backend.NewContext()
			c.copyRecurrentState(ctx, slot, newSlot)
			c.copyCheckpoints(ctx, slot, newSlot)
			if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
				ctx.Compute()
			}
			ctx.Close()

			c.refCount[slot]--
			c.refCount[newSlot] = 1
			c.slotForSeq[seq] = newSlot
			slot = newSlot
		}

		c.shiftCheckpoints(slot, beginIndex, endIndex)
		return nil
	}

	if beginIndex > 0 {
		restore, ok := c.pendingRestore[seq]
		if !ok || restore.pos+1 != beginIndex {
			return ErrNotSupported
		}
		if !c.restoreComplete(restore) {
			return ErrNotSupported
		}
		if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
			newSlot, err := c.allocSlot()
			if err != nil {
				return err
			}
			ctx := c.backend.NewContext()
			c.copyRecurrentState(ctx, slot, newSlot)
			c.copyCheckpoints(ctx, slot, newSlot)
			if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
				ctx.Compute()
			}
			ctx.Close()

			c.refCount[slot]--
			c.refCount[newSlot] = 1
			c.slotForSeq[seq] = newSlot

			restore.slot = newSlot
			c.pendingRestore[seq] = restore
		}
	}

	if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
		return err
	}

	if beginIndex > 0 {
		restore := c.pendingRestore[seq]
		delete(c.pendingRestore, seq)
		return c.applyCheckpointRestore(restore)
	}

	slot, ok := c.slotForSeq[seq]
	delete(c.pendingRestore, seq)
	if !ok {
		return nil
	}

	if !c.validSlot(slot) {
		delete(c.slotForSeq, seq)
		return nil
	}

	c.refCount[slot]--
	if c.refCount[slot] <= 0 {
		c.refCount[slot] = 0
		c.clearCheckpoints(slot)
		c.freeSlot(slot)
	}
	delete(c.slotForSeq, seq)

	return nil
}
func (c *Recurrent) validSlot(slot int) bool {
	return slot >= 0 && slot < len(c.refCount)
}

func (c *Recurrent) SlotsTensor() ml.Tensor {
	return c.curSlotsInput
}

// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *Recurrent) contiguousSlots() (int, bool) {
	if len(c.curSlots) == 0 {
		return 0, false
	}
	start := c.curSlots[0]
	for i, s := range c.curSlots {
		if s != start+i {
			return 0, false
		}
	}
	return start, true
}

func (c *Recurrent) SeqTokens() int {
	return c.curSeqTokens
}

func (c *Recurrent) NumSeqs() int {
	return len(c.curSeqs)
}

func (c *Recurrent) convBuffer(layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
	if buf, ok := c.recurrentStates[layer]; ok {
		return buf
	}

	if _, ok := c.recurrentCtxs[layer]; !ok {
		c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
	c.recurrentStates[layer] = buf
	return buf
}
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
	c.ensureWritableOnce(ctx)
	return c.writableError
}

func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
	if start, ok := c.contiguousSlots(); ok {
		offset := start * buf.Stride(1)
		return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
	}

	return buf.Rows(ctx, c.SlotsTensor())
}

func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
	if start, ok := c.contiguousSlots(); ok {
		offset := start * buf.Stride(1)
		view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
		ctx.Forward(src.Copy(ctx, view))
		return
	}

	ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
}

func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}
}
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}

	buf := c.convBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
	return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
}

// UpdateConvState writes new conv state for current batch sequences.
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(layer)
	src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
	srcF32 := src
	if src.DType() != ml.DTypeF32 {
		srcF32 = src.Cast(ctx, ml.DTypeF32)
	}
	c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)

	c.captureConvCheckpoint(ctx, layer, srcF32)
}

// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}
	if len(dims) == 0 {
		return nil, ErrInvalidRecurrentShape
	}

	size := 1
	for _, d := range dims {
		if d <= 0 {
			return nil, ErrInvalidRecurrentShape
		}
		size *= d
	}
	if size != c.recurrentStateSize {
		return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
	}

	buf := c.recurrentBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
	shape := make([]int, 0, len(dims)+1)
	shape = append(shape, dims...)
	shape = append(shape, c.NumSeqs())
	return cur.Reshape(ctx, shape...), nil
}

// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}
	if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
		return nil, ErrInvalidRecurrentShape
	}

	size := dim0 * dim1 * dim2
	if size != c.recurrentStateSize {
		return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
	}

	buf := c.recurrentBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
	return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
}

// UpdateRecurrentState writes new recurrent state for current batch sequences.
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.recurrentBuffer(layer)
	src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
	srcF32 := src
	if src.DType() != ml.DTypeF32 {
		srcF32 = src.Cast(ctx, ml.DTypeF32)
	}
	c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)

	c.captureRecurrentCheckpoint(ctx, layer, srcF32)
}

// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *Recurrent) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *Recurrent) Seqs() []int {
	return slices.Clone(c.curSeqs)
}
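// A sketch of how a recurrent layer's forward pass would consume these
// accessors; the layer code, headDim, and nHeads are hypothetical:
//
//	conv, err := cache.ConvState(ctx, layer) // [convDim, convChannels, nSeqs]
//	if err != nil {
//		return nil, err
//	}
//	state, err := cache.RecurrentState(ctx, layer, headDim, nHeads)
//	if err != nil {
//		return nil, err
//	}
//	// ... compute newConv and newState for this batch ...
//	cache.UpdateConvState(ctx, layer, newConv)     // also captures checkpoints
//	cache.UpdateRecurrentState(ctx, layer, newState)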
kvcache/recurrent_checkpoints.go (new file, 561 lines)
@@ -0,0 +1,561 @@
package kvcache

import (
	"log/slog"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.

type checkpointEntry struct {
	pos       int32
	conv      map[int]ml.Tensor
	recurrent map[int]ml.Tensor
}

type slotCheckpointStore struct {
	entries []checkpointEntry
	size    int
	next    int
	lastPos int32
}

type checkpointRestore struct {
	slot int
	idx  int
	pos  int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
	entries := make([]checkpointEntry, n)
	for i := range entries {
		entries[i].pos = -1
	}
	return &slotCheckpointStore{
		entries: entries,
		lastPos: -1,
	}
}

func (s *slotCheckpointStore) reset() {
	s.size = 0
	s.next = 0
	s.lastPos = -1
	for i := range s.entries {
		s.entries[i].pos = -1
	}
}

func (s *slotCheckpointStore) record(pos int32) int {
	if len(s.entries) == 0 {
		return -1
	}
	idx := s.next
	s.next = (s.next + 1) % len(s.entries)
	if s.size < len(s.entries) {
		s.size++
	}
	s.entries[idx].pos = pos
	s.lastPos = pos
	return idx
}
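// The store is a fixed-capacity ring: once full, record overwrites the
// oldest entry. A small sketch mirroring the tests in
// recurrent_checkpoints_test.go:
//
//	store := newSlotCheckpointStore(2)
//	store.record(10)
//	store.record(20)
//	store.record(30)                  // ring is full: overwrites pos 10
//	_, pos, ok := store.bestIndex(25) // -> pos == 20, ok == true
//	_, _, ok = store.bestIndex(15)    // -> ok == false, pos 10 was evicted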
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
	bestIdx := -1
	bestPos := int32(-1)
	for i := range s.entries {
		pos := s.entries[i].pos
		if pos < 0 || pos >= targetPos {
			continue
		}
		if pos > bestPos {
			bestPos = pos
			bestIdx = i
		}
	}
	if bestIdx < 0 {
		return -1, -1, false
	}
	return bestIdx, bestPos, true
}

func (s *slotCheckpointStore) pruneAfter(pos int32) {
	if len(s.entries) == 0 {
		s.size = 0
		s.next = 0
		s.lastPos = -1
		return
	}

	size := 0
	next := -1
	minPos := int32(math.MaxInt32)
	minIdx := 0
	for i := range s.entries {
		if s.entries[i].pos > pos {
			s.entries[i].pos = -1
		}
		if s.entries[i].pos >= 0 {
			size++
			if s.entries[i].pos < minPos {
				minPos = s.entries[i].pos
				minIdx = i
			}
		} else if next == -1 {
			next = i
		}
	}

	s.size = size
	if size == 0 {
		s.next = 0
		s.lastPos = -1
		return
	}
	if next != -1 {
		s.next = next
	} else {
		// Full ring: overwrite the oldest checkpoint next.
		s.next = minIdx
	}
	s.lastPos = pos
}
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
	if len(s.entries) == 0 {
		s.size = 0
		s.next = 0
		s.lastPos = -1
		return
	}

	offset := beginIndex - endIndex

	size := 0
	next := -1
	minPos := int32(math.MaxInt32)
	maxPos := int32(-1)
	minIdx := 0

	for i := range s.entries {
		pos := s.entries[i].pos
		if pos >= 0 {
			if pos >= beginIndex && pos < endIndex {
				s.entries[i].pos = -1
			} else if pos >= endIndex {
				s.entries[i].pos = pos + offset
			}
		}

		pos = s.entries[i].pos
		if pos >= 0 {
			size++
			if pos < minPos {
				minPos = pos
				minIdx = i
			}
			if pos > maxPos {
				maxPos = pos
			}
		} else if next == -1 {
			next = i
		}
	}

	s.size = size
	if size == 0 {
		s.next = 0
		s.lastPos = -1
		return
	}

	if next != -1 {
		s.next = next
	} else {
		// Full ring: overwrite the oldest checkpoint next.
		s.next = minIdx
	}
	s.lastPos = maxPos
}

func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
	minPos = int32(math.MaxInt32)
	maxPos = int32(-1)
	for i := range s.entries {
		pos := s.entries[i].pos
		if pos < 0 {
			continue
		}
		size++
		if pos < minPos {
			minPos = pos
		}
		if pos > maxPos {
			maxPos = pos
		}
	}
	if size == 0 {
		minPos = -1
		maxPos = -1
	}
	return size, minPos, maxPos, s.lastPos
}

func (c *Recurrent) checkpointTag() string {
	if c.logPrefix == "" {
		return "kvcache.recurrent"
	}
	return c.logPrefix
}
func (c *Recurrent) planCheckpoints(batch input.Batch) {
	if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
		c.curCheckpointPos = c.curCheckpointPos[:0]
		for k := range c.curCheckpointSlots {
			delete(c.curCheckpointSlots, k)
		}
		return
	}

	if cap(c.curCheckpointPos) < len(c.curSeqs) {
		c.curCheckpointPos = make([]int32, len(c.curSeqs))
	} else {
		c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
	}
	for i := range c.curCheckpointPos {
		c.curCheckpointPos[i] = -1
	}
	for k := range c.curCheckpointSlots {
		delete(c.curCheckpointSlots, k)
	}

	posMax := make(map[int]int32, len(c.curSeqs))
	for i, seq := range batch.Sequences {
		pos := batch.Positions[i]
		if cur, ok := posMax[seq]; !ok || pos > cur {
			posMax[seq] = pos
		}
	}

	for i, seq := range c.curSeqs {
		pos, ok := posMax[seq]
		if !ok {
			continue
		}
		if pos < c.checkpointMinPos {
			continue
		}
		slot := c.curSlots[i]
		store := c.checkpointStore(slot)
		lastPos := store.lastPos
		if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
			c.curCheckpointPos[i] = pos
		}
	}
}

func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
	store, ok := c.checkpoints[slot]
	if ok {
		return store
	}
	store = newSlotCheckpointStore(c.checkpointCount)
	c.checkpoints[slot] = store
	return store
}

func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
	if c.checkpointCount == 0 {
		return -1
	}
	if idx, ok := c.curCheckpointSlots[slot]; ok {
		return idx
	}
	store := c.checkpointStore(slot)
	idx := store.record(pos)
	if idx >= 0 {
		c.curCheckpointSlots[slot] = idx
	}
	return idx
}
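// With the package defaults this yields a sparse cadence. A hypothetical
// single-sequence prefill, assuming DefaultCheckpointMinPos (16) and
// DefaultCheckpointInterval (1280):
//
//	batch max pos 10   -> below checkpointMinPos: no checkpoint planned
//	batch max pos 500  -> lastPos < 0: first checkpoint recorded at 500
//	batch max pos 900  -> 900-500 < 1280: skipped
//	batch max pos 1900 -> 1900-500 >= 1280: checkpoint recorded at 1900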
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
	if pos <= 0 {
		return false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		return false
	}
	_, _, ok = store.bestIndex(pos)
	return ok
}

func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
	if targetPos <= 0 {
		return 0, false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return 0, false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
		return 0, false
	}
	idx, pos, ok := store.bestIndex(targetPos)
	if !ok {
		size, minPos, maxPos, lastPos := store.window()
		slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
			"min", minPos, "max", maxPos, "last", lastPos)
		return 0, false
	}
	c.pendingRestore[seq] = checkpointRestore{
		slot: slot,
		idx:  idx,
		pos:  pos,
	}
	return pos + 1, true
}
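// A sketch of the restore protocol from a runner's perspective, assuming
// sequence 1 diverges from its cached prefix at position 42:
//
//	if keep, ok := cache.PrepareRestore(1, 42); ok {
//		// keep == checkpointPos + 1; Remove applies the pending restore
//		// and truncates the KV cache and checkpoint ring behind it.
//		if err := cache.Remove(1, keep, math.MaxInt32); err == nil {
//			// resume prefill from position keep instead of position 0
//		}
//	} else {
//		// no usable checkpoint: reprocess the sequence from scratch
//	}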
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
	entry, ok := c.restoreEntry(restore)
	if !ok {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
	for layer, src := range entry.conv {
		buf := c.convBuffer(layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}
	for layer, src := range entry.recurrent {
		buf := c.recurrentBuffer(layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}

	if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
		ctx.Compute()
	}
	store := c.checkpoints[restore.slot]
	store.pruneAfter(restore.pos)
	return nil
}

func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
	_, ok := c.restoreEntry(restore)
	return ok
}

func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
	store, ok := c.checkpoints[restore.slot]
	if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
		return nil, false
	}
	entry := &store.entries[restore.idx]
	if entry.pos < 0 {
		return nil, false
	}
	if !c.entryComplete(entry) {
		return nil, false
	}
	return entry, true
}

func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
	for layer := range c.convStates {
		if entry.conv == nil || entry.conv[layer] == nil {
			return false
		}
	}
	for layer := range c.recurrentStates {
		if entry.recurrent == nil || entry.recurrent[layer] == nil {
			return false
		}
	}
	return true
}

func (c *Recurrent) clearCheckpoints(slot int) {
	if store, ok := c.checkpoints[slot]; ok {
		store.reset()
	}
}

func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
	if store, ok := c.checkpoints[slot]; ok {
		store.shiftRange(beginIndex, endIndex)
	}
}

func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
	if c.checkpointCount == 0 {
		return
	}
	srcStore, ok := c.checkpoints[srcSlot]
	if !ok || srcStore.size == 0 {
		return
	}
	dstStore := c.checkpointStore(dstSlot)
	dstStore.size = srcStore.size
	dstStore.next = srcStore.next
	dstStore.lastPos = srcStore.lastPos

	for i := range srcStore.entries {
		srcEntry := &srcStore.entries[i]
		dstEntry := &dstStore.entries[i]
		dstEntry.pos = srcEntry.pos
		if srcEntry.conv != nil {
			if dstEntry.conv == nil {
				dstEntry.conv = make(map[int]ml.Tensor)
			}
			for layer, src := range srcEntry.conv {
				dst := c.ensureCheckpointConv(layer, dstEntry)
				ctx.Forward(src.Copy(ctx, dst))
			}
		}
		if srcEntry.recurrent != nil {
			if dstEntry.recurrent == nil {
				dstEntry.recurrent = make(map[int]ml.Tensor)
			}
			for layer, src := range srcEntry.recurrent {
				dst := c.ensureCheckpointRecurrent(layer, dstEntry)
				ctx.Forward(src.Copy(ctx, dst))
			}
		}
	}
}
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointConv(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointConv(layer, entry)
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}

func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointRecurrent(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointRecurrent(layer, entry)
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}

func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.conv == nil {
		entry.conv = make(map[int]ml.Tensor)
	}
	if t, ok := entry.conv[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointConvCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointConvCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
	entry.conv[layer] = t
	return t
}

func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.recurrent == nil {
		entry.recurrent = make(map[int]ml.Tensor)
	}
	if t, ok := entry.recurrent[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointRecurCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointRecurCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
	entry.recurrent[layer] = t
	return t
}

func (c *Recurrent) reserveCheckpointConv(layer int) {
	key := checkpointReserveKey(layer, 0)
	if _, ok := c.checkpointReserved[key]; ok {
		return
	}
	for slot := range c.maxSequences {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			entry := &store.entries[i]
			_ = c.ensureCheckpointConv(layer, entry)
		}
	}
	c.checkpointReserved[key] = struct{}{}
}

func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
	key := checkpointReserveKey(layer, 1)
	if _, ok := c.checkpointReserved[key]; ok {
		return
	}
	for slot := range c.maxSequences {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			entry := &store.entries[i]
			_ = c.ensureCheckpointRecurrent(layer, entry)
		}
	}
	c.checkpointReserved[key] = struct{}{}
}

func checkpointReserveKey(layer int, kind int) int {
	return layer*2 + kind
}
kvcache/recurrent_checkpoints_test.go (new file, 288 lines)
@@ -0,0 +1,288 @@
package kvcache

import (
	"errors"
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
)

func newTestCache() *Recurrent {
	return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
}

func TestSlotCheckpointStoreBestIndex(t *testing.T) {
	store := newSlotCheckpointStore(2)
	store.record(10)
	store.record(20)

	_, pos, ok := store.bestIndex(15)
	if !ok || pos != 10 {
		t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
	}

	store.record(30) // overwrite oldest (10)

	if _, _, ok := store.bestIndex(15); ok {
		t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
	}

	_, pos, ok = store.bestIndex(40)
	if !ok || pos != 30 {
		t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
	}
}

func TestCachePrepareRestore(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	store := cache.checkpointStore(0)
	store.record(5)
	store.record(9)
	store.record(15)

	restorePos, ok := cache.PrepareRestore(1, 12)
	if !ok {
		t.Fatalf("expected restore ok")
	}
	if restorePos != 10 {
		t.Fatalf("expected restorePos 10, got %d", restorePos)
	}
	rest, ok := cache.pendingRestore[1]
	if !ok {
		t.Fatalf("expected pending restore entry")
	}
	if rest.pos != 9 {
		t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
	}
}

func TestSlotCheckpointStorePruneAfter(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	store.pruneAfter(20)

	if store.lastPos != 20 {
		t.Fatalf("expected lastPos 20, got %d", store.lastPos)
	}

	_, pos, ok := store.bestIndex(25)
	if !ok || pos != 20 {
		t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
	}

	_, pos, ok = store.bestIndex(35)
	if !ok || pos != 20 {
		t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
	}
}

func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil

	// Simulate layer 0 requires both conv and recurrent checkpoints.
	cache.convStates[0] = nil
	cache.recurrentStates[0] = nil

	store := cache.checkpointStore(0)
	idx := store.record(9)
	entry := &store.entries[idx]
	entry.conv = map[int]ml.Tensor{0: nil}
	// entry.recurrent intentionally missing

	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}

	err := cache.Remove(1, 10, math.MaxInt32)
	if !errors.Is(err, ErrNotSupported) {
		t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
	}
}

func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil

	store := cache.checkpointStore(0)
	idx := store.record(9)

	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}

	restore := cache.pendingRestore[1]
	if !cache.restoreComplete(restore) {
		t.Fatalf("expected restoreComplete to return true for complete checkpoint")
	}
}

func TestCacheRecurrentStateShapeValidation(t *testing.T) {
	cache := newTestCache()
	_, err := cache.RecurrentState(nil, 0, 3)
	if !errors.Is(err, ErrInvalidRecurrentShape) {
		t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
	}
}

func TestSlotCheckpointStoreShiftRange(t *testing.T) {
	store := newSlotCheckpointStore(5)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	store.shiftRange(2, 6)

	var positions []int32
	for i := range store.entries {
		if store.entries[i].pos >= 0 {
			positions = append(positions, store.entries[i].pos)
		}
	}
	slices.Sort(positions)

	want := []int32{1, 3, 6}
	if !slices.Equal(positions, want) {
		t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
	}
	if store.lastPos != 6 {
		t.Fatalf("expected lastPos 6, got %d", store.lastPos)
	}
}

func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
	cache := newTestCache()
	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}

	store := cache.checkpointStore(0)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	if err := cache.Remove(1, 2, 6); err != nil {
		t.Fatalf("expected middle remove to succeed, got %v", err)
	}

	if _, ok := cache.pendingRestore[1]; ok {
		t.Fatalf("expected pending restore to be cleared after middle remove")
	}

	var positions []int32
	for i := range store.entries {
		if store.entries[i].pos >= 0 {
			positions = append(positions, store.entries[i].pos)
		}
	}
	slices.Sort(positions)

	want := []int32{1, 3, 6}
	if !slices.Equal(positions, want) {
		t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
	}
}

func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
	store := newSlotCheckpointStore(3)

	store.record(10)
	store.record(20)
	store.record(30)

	store.entries[0].conv = make(map[int]ml.Tensor)
	store.entries[0].conv[0] = nil
	store.entries[0].recurrent = make(map[int]ml.Tensor)
	store.entries[0].recurrent[0] = nil

	store.record(40)

	if store.entries[0].conv == nil {
		t.Fatalf("expected conv map to be preserved on reuse")
	}
	if store.entries[0].recurrent == nil {
		t.Fatalf("expected recurrent map to be preserved on reuse")
	}
	if store.entries[0].pos != 40 {
		t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
	}
}

func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
	store := newSlotCheckpointStore(2)

	idx1 := store.record(10)
	idx2 := store.record(20)

	if idx1 != 0 || idx2 != 1 {
		t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
	}
	if store.size != 2 {
		t.Fatalf("expected size 2, got %d", store.size)
	}

	_, pos1, ok1 := store.bestIndex(15)
	_, pos2, ok2 := store.bestIndex(25)

	if !ok1 || pos1 != 10 {
		t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
	}
	if !ok2 || pos2 != 20 {
		t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
	}
}

func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
	store := newSlotCheckpointStore(0)

	idx := store.record(10)
	if idx != -1 {
		t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
	}

	_, _, ok := store.bestIndex(15)
	if ok {
		t.Fatalf("expected no checkpoint for empty buffer")
	}
}

func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	store.pruneAfter(5)

	if store.size != 0 {
		t.Fatalf("expected size 0 after pruning all, got %d", store.size)
	}
	if store.lastPos != -1 {
		t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
	}

	_, _, ok := store.bestIndex(100)
	if ok {
		t.Fatalf("expected no checkpoint after pruning all")
	}
}
@@ -0,0 +1,37 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
 specialization

---
 ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
 ggml/src/ggml-metal/ggml-metal.metal   | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
     // ne21 = n_rows (batch size)
     const int ne21_mm_id_min = 32;
 
-    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+        (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
         //switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
 
 template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
 kernel void kernel_mul_mm_id(
@@ -163,6 +163,7 @@ type Tensor interface {
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
	SSMConv(ctx Context, kernel Tensor) Tensor
	SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor

	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
	}
}

func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
	}
}

func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	return &Tensor{
		b: t.b,
@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
     // ne21 = n_rows (batch size)
     const int ne21_mm_id_min = 32;
 
-    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+        (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
         //switch (op->src[0]->type) {
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }

func (m *fakeBackend) Get(name string) ml.Tensor {
	if slices.Contains(m.names, name) {
@@ -1,410 +1,44 @@
package lfm2

import (
    "slices"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)
var (
    _ kvcache.Cache           = (*HybridCache)(nil)
    _ kvcache.CheckpointCache = (*HybridCache)(nil)
)

// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
// HybridCache adapts the shared recurrent cache for LFM2:
// - KV attention cache is handled by the embedded causal cache
// - shortconv recurrent state uses conv slots [dConv, hiddenSize]
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
// This reuses shared checkpoint/restore logic for prefix mismatch recovery.
type HybridCache struct {
    kv *kvcache.Causal

    backend      ml.Backend
    dtype        ml.DType
    maxSequences int

    hiddenSize int
    dConv      int

    // slot mapping for recurrent state
    slotForSeq map[int]int
    refCount   []int
    freeSlots  []int

    // per-layer conv state buffers (allocated lazily)
    convCtxs   map[int]ml.Context
    convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]

    // current forward batch (derived in StartForward)
    curSeqs       []int
    curSlots      []int
    curSlotsInput ml.Tensor
    curSeqTokens  int

    // track if EnsureWritable has been called for this forward pass
    writableEnsured bool
    // track any error from EnsureWritable to propagate later
    writableError error
    *kvcache.Recurrent
}
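The comment above describes a per-layer buffer of shape [dConv * hiddenSize, maxSlots], one contiguous block per slot. A minimal sketch of that flattening on plain Go slices (hypothetical helper names; assumes ggml-style ordering where the first dimension is contiguous, with a slot's [dConv, hiddenSize] block laid out channel-major):

// flatIndex maps (slot, channel, tap) into a [dConv*hiddenSize, maxSlots]
// buffer where each slot owns one contiguous block of dConv*hiddenSize values.
func flatIndex(slot, channel, tap, dConv, hiddenSize int) int {
    return slot*(dConv*hiddenSize) + channel*dConv + tap
}

Under these assumptions, Rows/SetRows with a slot index tensor reads or writes exactly one such block per selected slot.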
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
    return &HybridCache{
        kv:         kvcache.NewCausalCache(shift),
        hiddenSize: hiddenSize,
        dConv:      dConv,
        slotForSeq: make(map[int]int),
        convCtxs:   make(map[int]ml.Context),
        convStates: make(map[int]ml.Tensor),
    }
}
    base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
        Shift:               shift,
        ConvDim:             dConv,
        ConvChannels:        hiddenSize,
        RecurrentStateSize:  1, // LFM2 uses only conv state; keep a minimal recurrent buffer size.
        CheckpointLogPrefix: "lfm2",
    })

func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
    c.backend = backend
    c.dtype = dtype
    c.maxSequences = maxSequences

    // initialize slot allocator
    c.refCount = make([]int, maxSequences)
    c.freeSlots = c.freeSlots[:0]
    for i := maxSequences - 1; i >= 0; i-- {
        c.freeSlots = append(c.freeSlots, i)
    }

    c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *HybridCache) Close() {
    for _, ctx := range c.convCtxs {
        ctx.Close()
    }
    c.kv.Close()
}

func (c *HybridCache) SetConfig(config ml.CacheConfig) {
    c.kv.SetConfig(config)
}

func (c *HybridCache) SetLayer(layer int) {
    c.kv.SetLayer(layer)
}

func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
    return c.kv.Get(ctx)
}

func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
    c.kv.Put(ctx, key, value)
}

func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
    if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
        return err
    }

    // Derive equal-length sequence layout for shortconv.
    // LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
    seqCounts := make(map[int]int)
    c.curSeqs = c.curSeqs[:0]
    for _, s := range batch.Sequences {
        if _, ok := seqCounts[s]; !ok {
            c.curSeqs = append(c.curSeqs, s)
        }
        seqCounts[s]++
    }

    if len(c.curSeqs) == 0 {
        return nil
    }

    nTokens := len(batch.Sequences)
    nSeqs := len(c.curSeqs)
    want := nTokens / nSeqs
    for _, s := range c.curSeqs {
        if seqCounts[s] != want {
            return kvcache.ErrNotSupported
        }
    }

    c.curSeqTokens = want

    // When reserving memory for estimation, use fake slot assignments
    // without modifying permanent state (slotForSeq, refCount)
    if reserve {
        c.curSlots = c.curSlots[:0]
        slots := make([]int32, nSeqs)
        for i := range nSeqs {
            c.curSlots = append(c.curSlots, i)
            slots[i] = int32(i)
        }
        c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
        return nil
    }

    // Ensure slots exist for sequences in this batch
    c.curSlots = c.curSlots[:0]
    var newSlots []int // track newly allocated slots that need zeroing
    for _, s := range c.curSeqs {
        slot, ok := c.slotForSeq[s]
        if !ok {
            var err error
            slot, err = c.allocSlot()
            if err != nil {
                return err
            }
            c.slotForSeq[s] = slot
            c.refCount[slot] = 1
            newSlots = append(newSlots, slot)
        }
        c.curSlots = append(c.curSlots, slot)
    }

    // Zero conv state for newly allocated slots to clear stale data from previous sequences
    if len(newSlots) > 0 {
        c.zeroConvSlots(ctx, newSlots)
    }

    // Create a tensor for the current slots
    slots := make([]int32, len(c.curSlots))
    for i, v := range c.curSlots {
        slots[i] = int32(v)
    }
    c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

    // Reset writable state for new forward pass
    c.writableEnsured = false
    c.writableError = nil

    return nil
}
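The layout check in StartForward above rejects any batch in which sequences contribute unequal token counts, since shortconv needs a regular [seq_tokens, seqs] grid. The same check in isolation, as a hedged standalone sketch (hypothetical helper, plain Go, no ml types):

// seqTokensPerBatch returns the tokens-per-sequence count when the batch
// forms an even [seq_tokens, seqs] grid, and ok=false when it is ragged.
func seqTokensPerBatch(sequences []int) (tokens int, ok bool) {
    counts := make(map[int]int)
    for _, s := range sequences {
        counts[s]++
    }
    if len(counts) == 0 {
        return 0, false
    }
    want := len(sequences) / len(counts)
    for _, n := range counts {
        if n != want {
            return 0, false // mirrors the kvcache.ErrNotSupported path
        }
    }
    return want, true
}

For example, sequences [0,0,1,1] yield (2, true), while [0,0,1] yield (0, false).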
func (c *HybridCache) allocSlot() (int, error) {
    if len(c.freeSlots) == 0 {
        return 0, kvcache.ErrKvCacheFull
    }
    slot := c.freeSlots[len(c.freeSlots)-1]
    c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
    return slot, nil
}

func (c *HybridCache) freeSlot(slot int) {
    // Bounds check before freeing
    if slot >= 0 && slot < c.maxSequences {
        c.freeSlots = append(c.freeSlots, slot)
    }
}
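allocSlot and freeSlot above implement a LIFO free list, so the most recently freed slot is reused first (the behavior TestHybridCache_SlotReuse relies on). A toy version with the same semantics, under hypothetical names:

type slotPool struct {
    free []int
    max  int
}

func newSlotPool(n int) *slotPool {
    p := &slotPool{max: n}
    for i := n - 1; i >= 0; i-- {
        p.free = append(p.free, i) // slot 0 ends up on top of the stack
    }
    return p
}

func (p *slotPool) alloc() (int, bool) {
    if len(p.free) == 0 {
        return 0, false // pool exhausted, analogous to ErrKvCacheFull
    }
    s := p.free[len(p.free)-1]
    p.free = p.free[:len(p.free)-1]
    return s, true
}

func (p *slotPool) release(s int) {
    if s >= 0 && s < p.max { // mirror freeSlot's bounds check
        p.free = append(p.free, s)
    }
}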
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
    if len(slots) == 0 || len(c.convStates) == 0 {
        return
    }

    // Use input context for creating tensors
    inputCtx := ctx.Input()

    // Create slot indices tensor
    slotIndices := make([]int32, len(slots))
    for i, s := range slots {
        slotIndices[i] = int32(s)
    }
    slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

    // Create zero tensor for the slots (SetRows requires F32 source)
    zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))

    // Zero each layer's conv state for these slots
    for _, buf := range c.convStates {
        ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
    }
}

// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
    for i, seq := range c.curSeqs {
        slot, ok := c.slotForSeq[seq]
        if !ok {
            continue
        }

        // Bounds check
        if slot < 0 || slot >= len(c.refCount) {
            continue
        }

        if c.refCount[slot] <= 1 {
            continue
        }

        newSlot, err := c.allocSlot()
        if err != nil {
            return err
        }
        c.refCount[slot]--
        c.refCount[newSlot] = 1
        c.slotForSeq[seq] = newSlot
        c.curSlots[i] = newSlot

        // Copy existing conv state for all initialized layers
        for _, buf := range c.convStates {
            // buf: [dConv*hiddenSize, maxSlots]
            src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
            // SetRows requires F32 source
            srcF32 := src.Cast(ctx, ml.DTypeF32)
            ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
        }
    }

    // Rebuild current slots tensor
    slots := make([]int32, len(c.curSlots))
    for i, v := range c.curSlots {
        slots[i] = int32(v)
    }
    c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

    return nil
}
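EnsureWritable above gives a shared slot a private copy on first write. The reference-counting contract it relies on, demonstrated on plain ints (a hypothetical standalone sketch; no tensors involved):

// cowFork sketches the bookkeeping behind CopyPrefix + EnsureWritable.
func cowFork() (slotForSeq map[int]int, refCount []int) {
    refCount = make([]int, 2)
    slotForSeq = map[int]int{1: 0}
    refCount[0] = 1

    // CopyPrefix-style fork: seq 2 shares seq 1's slot.
    slotForSeq[2] = slotForSeq[1]
    refCount[0]++ // slot 0 now has two readers

    // First write to seq 2 triggers the EnsureWritable path:
    // allocate a private slot, copy the state, rebalance refcounts.
    newSlot := 1
    refCount[0]--         // seq 1 keeps slot 0 alone again
    refCount[newSlot] = 1 // seq 2 owns the copy
    slotForSeq[2] = newSlot
    return slotForSeq, refCount
}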
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
    // KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
    c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

    // For shortconv state we implement copy-on-write: dst shares the same slot as src.
    // On the first write to dst, EnsureWritable will create a private slot.
    if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
        // Bounds check before decrementing
        if dstSlot >= 0 && dstSlot < len(c.refCount) {
            c.refCount[dstSlot]--
            if c.refCount[dstSlot] <= 0 {
                c.refCount[dstSlot] = 0
                c.freeSlot(dstSlot)
            }
        }
        delete(c.slotForSeq, dstSeq)
    }

    srcSlot, ok := c.slotForSeq[srcSeq]
    if !ok {
        // src may not have a slot yet; dst will allocate on demand
        return
    }

    // Bounds check before incrementing
    if srcSlot >= 0 && srcSlot < len(c.refCount) {
        c.slotForSeq[dstSeq] = srcSlot
        c.refCount[srcSlot]++
    }
}

func (c *HybridCache) CanResume(seq int, pos int32) bool {
    return c.kv.CanResume(seq, pos)
}

func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
    if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
        return err
    }

    // For recurrent state, any removal invalidates the state because
    // the state at position N depends on all previous positions.
    // Drop the slot mapping so it resets on next use.
    slot, ok := c.slotForSeq[seq]
    if !ok {
        return nil
    }

    // Bounds check
    if slot < 0 || slot >= len(c.refCount) {
        delete(c.slotForSeq, seq)
        return nil
    }

    c.refCount[slot]--
    if c.refCount[slot] <= 0 {
        c.refCount[slot] = 0
        c.freeSlot(slot)
    }
    delete(c.slotForSeq, seq)

    return nil
    return &HybridCache{Recurrent: base}
}

func (c *HybridCache) slotsTensor() ml.Tensor {
    return c.curSlotsInput
    return c.SlotsTensor()
}

func (c *HybridCache) seqTokens() int {
    return c.curSeqTokens
    return c.SeqTokens()
}

func (c *HybridCache) numSeqs() int {
    return len(c.curSeqs)
}

func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
    if buf, ok := c.convStates[layer]; ok {
        return buf
    }

    if _, ok := c.convCtxs[layer]; !ok {
        c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
    }

    buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
    c.convStates[layer] = buf
    return buf
}

// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
    if !c.writableEnsured {
        needsWritable := false
        for _, seq := range c.curSeqs {
            slot, ok := c.slotForSeq[seq]
            if !ok {
                continue
            }
            if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
                needsWritable = true
                break
            }
        }

        if needsWritable {
            if err := c.EnsureWritable(ctx); err != nil {
                c.writableError = err
            }
        }
        c.writableEnsured = true
    }

    if c.writableError != nil {
        return nil, c.writableError
    }

    buf := c.convBuffer(ctx, layer)
    cur := buf.Rows(ctx, c.slotsTensor())
    return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
    buf := c.convBuffer(ctx, layer)
    src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
    // SetRows requires F32 source
    srcF32 := src.Cast(ctx, ml.DTypeF32)
    ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}

// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
    return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
    return slices.Clone(c.curSeqs)
    return c.NumSeqs()
}
@@ -4,441 +4,39 @@ import (
    "testing"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
func TestHybridCache_New(t *testing.T) {
    cache := NewHybridCache(nil, 512, 2)
    if cache == nil {
        t.Fatal("expected cache to be created")
    }

// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
    return &HybridCache{
        hiddenSize:   256,
        dConv:        3,
        maxSequences: maxSequences,
        refCount:     make([]int, maxSequences),
        freeSlots:    initFreeSlots(maxSequences),
        slotForSeq:   make(map[int]int),
        convCtxs:     make(map[int]ml.Context),
        convStates:   make(map[int]ml.Tensor),
    if cache.Recurrent == nil {
        t.Fatal("expected embedded recurrent cache to be created")
    }
}

func initFreeSlots(n int) []int {
    slots := make([]int, 0, n)
    for i := n - 1; i >= 0; i-- {
        slots = append(slots, i)
    }
    return slots
}
func TestHybridCache_ImplementsCheckpointCache(t *testing.T) {
    cache := NewHybridCache(nil, 512, 2)

func TestHybridCache_SlotAllocation(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Verify initial state
    if len(cache.freeSlots) != 4 {
        t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
    }

    // Allocate all slots
    for range 4 {
        slot, err := cache.allocSlot()
        if err != nil {
            t.Fatalf("allocSlot failed: %v", err)
        }
        cache.refCount[slot] = 1
    }

    // Should be full now
    if len(cache.freeSlots) != 0 {
        t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
    }

    // Trying to allocate another should fail
    _, err := cache.allocSlot()
    if err != kvcache.ErrKvCacheFull {
        t.Errorf("expected ErrKvCacheFull, got %v", err)
    if _, ok := any(cache).(kvcache.CheckpointCache); !ok {
        t.Fatal("expected HybridCache to implement CheckpointCache")
    }
}

func TestHybridCache_SlotReuse(t *testing.T) {
    cache := createSlotOnlyCache(4)
func TestHybridCache_DefaultBatchState(t *testing.T) {
    cache := NewHybridCache(nil, 512, 2)

    // Allocate a slot
    slot1, _ := cache.allocSlot()
    cache.refCount[slot1] = 1

    // Free it
    cache.refCount[slot1] = 0
    cache.freeSlot(slot1)

    // Allocate again - should get the same slot back (LIFO)
    slot2, _ := cache.allocSlot()
    if slot2 != slot1 {
        t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
    }
}

func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Allocate slot for seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Simulate sharing slot with seq 2 (copy-on-write style)
    cache.slotForSeq[2] = slot1
    cache.refCount[slot1]++

    // Should share the same slot
    if cache.slotForSeq[2] != slot1 {
        t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
    if got := cache.numSeqs(); got != 0 {
        t.Fatalf("expected 0 sequences before StartForward, got %d", got)
    }

    // Ref count should be 2
    if cache.refCount[slot1] != 2 {
        t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
    }
}

func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Allocate slot for seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Share with seq 2
    cache.slotForSeq[2] = slot1
    cache.refCount[slot1]++

    // Unshare seq 2
    cache.refCount[slot1]--
    delete(cache.slotForSeq, 2)

    // Ref count should be back to 1
    if cache.refCount[slot1] != 1 {
        t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
    if got := cache.seqTokens(); got != 0 {
        t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got)
    }

    // Seq 2 should no longer have a slot
    if _, ok := cache.slotForSeq[2]; ok {
        t.Error("seq 2 should not have a slot after unshare")
    }
}

func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
    cache := createSlotOnlyCache(4)

    initialFreeSlots := len(cache.freeSlots)

    // Allocate slot for seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Free the slot when refCount drops to 0
    cache.refCount[slot1]--
    if cache.refCount[slot1] <= 0 {
        cache.refCount[slot1] = 0
        cache.freeSlot(slot1)
    }
    delete(cache.slotForSeq, 1)

    // Slot should be freed
    if len(cache.freeSlots) != initialFreeSlots {
        t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
    }

    // Ref count should be 0
    if cache.refCount[slot1] != 0 {
        t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
    }
}

func TestHybridCache_SlotOverwrite(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Allocate slots for seq 1 and seq 2
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    slot2, _ := cache.allocSlot()
    cache.slotForSeq[2] = slot2
    cache.refCount[slot2] = 1

    initialFreeSlots := len(cache.freeSlots)

    // Simulate overwriting seq 2's slot with slot1 (sharing)
    // First free the old slot
    cache.refCount[slot2]--
    if cache.refCount[slot2] <= 0 {
        cache.refCount[slot2] = 0
        cache.freeSlot(slot2)
    }
    // Then share slot1
    cache.slotForSeq[2] = slot1
    cache.refCount[slot1]++

    // Seq 2 should now share slot1
    if cache.slotForSeq[2] != slot1 {
        t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
    }

    // Old slot2 should be freed
    if len(cache.freeSlots) != initialFreeSlots+1 {
        t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
    }
}

func TestHybridCache_BoundsChecking(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Test freeing invalid slot (should not panic)
    cache.freeSlot(-1)
    cache.freeSlot(100) // out of bounds

    // freeSlot does bounds checking, so invalid slots should be ignored
    if len(cache.freeSlots) != 4 {
        t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
    }
}

func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
    cache := createSlotOnlyCache(8)

    // Allocate slot for seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Fork to seq 2, 3, 4 (all share slot1)
    for _, seq := range []int{2, 3, 4} {
        cache.slotForSeq[seq] = slot1
        cache.refCount[slot1]++
    }

    // Ref count should be 4
    if cache.refCount[slot1] != 4 {
        t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
    }

    // Remove seq 2, 3
    for _, seq := range []int{2, 3} {
        delete(cache.slotForSeq, seq)
        cache.refCount[slot1]--
    }

    if cache.refCount[slot1] != 2 {
        t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
    }

    // Slot should still be allocated (not in free list)
    found := false
    for _, s := range cache.freeSlots {
        if s == slot1 {
            found = true
            break
        }
    }
    if found {
        t.Error("slot1 should not be in free list yet")
    }

    // Remove remaining sequences
    for _, seq := range []int{1, 4} {
        delete(cache.slotForSeq, seq)
        cache.refCount[slot1]--
    }

    if cache.refCount[slot1] != 0 {
        t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
    }
}

func TestHybridCache_ChainedSharing(t *testing.T) {
    cache := createSlotOnlyCache(8)

    // Create seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Share 1 -> 2
    cache.slotForSeq[2] = slot1
    cache.refCount[slot1]++

    // Share 2 -> 3 (should still share slot1)
    cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
    cache.refCount[slot1]++

    // All should share slot1
    if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
        t.Error("all sequences should share slot1")
    }

    if cache.refCount[slot1] != 3 {
        t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
    }
}

func TestHybridCache_CacheParameters(t *testing.T) {
    cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5

    if cache.hiddenSize != 512 {
        t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
    }
    if cache.dConv != 5 {
        t.Errorf("expected dConv 5, got %d", cache.dConv)
    }
}

func TestHybridCache_NumSeqs(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Initially no sequences
    if cache.numSeqs() != 0 {
        t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
    }

    // Manually set up current batch state
    cache.curSeqs = []int{1, 2, 3}

    if cache.numSeqs() != 3 {
        t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
    }
}

func TestHybridCache_SeqTokens(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Initially 0
    if cache.seqTokens() != 0 {
        t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
    }

    // Manually set up current batch state
    cache.curSeqTokens = 16

    if cache.seqTokens() != 16 {
        t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
    }
}

// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
    cache := createSlotOnlyCache(4)

    cache.curSeqs = []int{1, 2, 3}

    seqs := cache.Seqs()

    // Modify returned slice
    seqs[0] = 999

    // Original should be unchanged
    if cache.curSeqs[0] != 1 {
        t.Error("Seqs should return a clone, not the original slice")
    }
}

func TestHybridCache_IsSupportedForBatch(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Initially not supported (no batch set up)
    if cache.IsSupportedForBatch() {
        t.Error("expected IsSupportedForBatch to be false initially")
    }

    // Set up a valid batch
    cache.curSeqTokens = 1
    cache.curSeqs = []int{1}

    if !cache.IsSupportedForBatch() {
        t.Error("expected IsSupportedForBatch to be true with valid batch")
    }
}

func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // zeroConvSlots should handle empty slots without panicking
    cache.zeroConvSlots(nil, nil)
    cache.zeroConvSlots(nil, []int{})

    // zeroConvSlots should handle empty convStates without panicking
    cache.zeroConvSlots(nil, []int{0, 1, 2})
}

func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Allocate slot for seq 1
    slot1, _ := cache.allocSlot()
    cache.slotForSeq[1] = slot1
    cache.refCount[slot1] = 1

    // Free the slot (simulating sequence removal)
    cache.refCount[slot1]--
    cache.freeSlot(slot1)
    delete(cache.slotForSeq, 1)

    // Verify slot is in free list
    if len(cache.freeSlots) != 4 {
        t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
    }

    // Allocate for new seq 2 - should get recycled slot
    slot2, _ := cache.allocSlot()
    if slot2 != slot1 {
        t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
    }

    // This recycled slot would need zeroing in the real implementation
    // The actual zeroing is tested via integration tests since it requires ML context
}

func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
    cache := createSlotOnlyCache(4)

    // Simulate the slot allocation flow from StartForward
    // When a sequence doesn't have a slot, it gets allocated and tracked as "new"

    newSlots := []int{}

    // Seq 1 doesn't have a slot - allocate and track
    seq := 1
    if _, ok := cache.slotForSeq[seq]; !ok {
        slot, err := cache.allocSlot()
        if err != nil {
            t.Fatalf("allocSlot failed: %v", err)
        }
        cache.slotForSeq[seq] = slot
        cache.refCount[slot] = 1
        newSlots = append(newSlots, slot)
    }

    // Verify newSlots contains the allocated slot
    if len(newSlots) != 1 {
        t.Errorf("expected 1 new slot, got %d", len(newSlots))
    }

    // Seq 1 already has a slot - should NOT be tracked as new
    newSlots2 := []int{}
    if _, ok := cache.slotForSeq[seq]; !ok {
        slot, _ := cache.allocSlot()
        cache.slotForSeq[seq] = slot
        cache.refCount[slot] = 1
        newSlots2 = append(newSlots2, slot)
    }

    // Verify no new slots for existing sequence
    if len(newSlots2) != 0 {
        t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
        t.Fatal("expected unsupported batch layout before StartForward")
    }
}
@@ -1,7 +1,11 @@
package lfm2

import (
    "bytes"
    "cmp"
    "errors"
    "fmt"
    "image"
    "math"

    "github.com/ollama/ollama/fs"
@@ -25,8 +29,20 @@ type Options struct {
    // per-layer head counts (LFM2 alternates attention and recurrent layers)
    numHeadsByLayer   []int
    numKVHeadsByLayer []int

    // MoE config
    numExperts         int
    numExpertsUsed     int
    normTopKProb       bool
    expertWeightsScale float32
    expertGatingFunc   uint32
}

const (
    expertGatingFuncSoftmax = uint32(0)
    expertGatingFuncSigmoid = uint32(2)
)

func (o Options) headDimValue() int {
    // Head dim is shared across layers; fall back to first attention layer head count.
    for _, h := range o.numHeadsByLayer {
@@ -67,18 +83,138 @@ type Model struct {
    OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
    Output     *nn.Linear  `gguf:"output,alt:token_embd"`

    VisionModel      *VisionModel     `gguf:"v"`
    VisionProjector  *VisionProjector `gguf:"mm"`
    ImageProcessor   ImageProcessor
    imageTokenID     int32
    imageStartToken  int32
    imageEndToken    int32
    imageThumbnailID int32
    imageRowColIDs   map[imageGridPos]int32
    useSpecialTokens bool
    projectorOptions VisionProjectorOptions

    Options
}

func New(c fs.Config) (model.Model, error) {
    if c.Uint("expert_count") > 0 {
        return nil, model.ErrUnsupportedModel
var _ model.MultimodalProcessor = (*Model)(nil)

type imageGridPos struct {
    row int
    col int
}

type visionEmbeddingLayout struct {
    rows         int
    cols         int
    hasThumbnail bool
}

type visionChunkData struct {
    tokens    int
    row       int
    col       int
    thumbnail bool
    layout    *visionEmbeddingLayout
}

func (m *Model) Validate() error {
    if m.TokenEmbedding == nil {
        return errors.New("lfm2: missing token_embd tensor")
    }
    if m.OutputNorm == nil {
        return errors.New("lfm2: missing output_norm tensor")
    }
    if m.Output == nil {
        return errors.New("lfm2: missing output tensor")
    }

    for i, layer := range m.Layers {
        if layer.AttentionNorm == nil {
            return fmt.Errorf("lfm2: missing blk.%d.attn_norm tensor", i)
        }
        if layer.MLPNorm == nil {
            return fmt.Errorf("lfm2: missing blk.%d.ffn_norm tensor", i)
        }
        switch ff := layer.MLP.(type) {
        case nil:
            return fmt.Errorf("lfm2: missing blk.%d feed-forward tensors", i)
        case *denseMLP:
            if ff.Up == nil || ff.Down == nil || ff.Gate == nil {
                return fmt.Errorf("lfm2: missing blk.%d dense feed-forward tensors", i)
            }
        case *sparseMLP:
            if ff.Router == nil || ff.Gate == nil || ff.Up == nil || ff.Down == nil {
                return fmt.Errorf("lfm2: missing blk.%d sparse feed-forward tensors", i)
            }
        default:
            return fmt.Errorf("lfm2: unsupported feed-forward type at blk.%d", i)
        }

        switch op := layer.Operator.(type) {
        case *Attention:
            if op == nil || op.Query == nil || op.Key == nil || op.Value == nil || op.Output == nil || op.QueryNorm == nil || op.KeyNorm == nil {
                return fmt.Errorf("lfm2: missing blk.%d attention tensors", i)
            }
        case *ShortConv:
            if op == nil || op.Conv == nil || op.Conv.Weight == nil || op.InProj == nil || op.OutProj == nil {
                return fmt.Errorf("lfm2: missing blk.%d shortconv tensors", i)
            }
        default:
            return fmt.Errorf("lfm2: unsupported operator at blk.%d", i)
        }
    }

    if m.VisionModel != nil {
        if m.VisionModel.PatchEmbedding == nil {
            return errors.New("lfm2: missing vision patch embedding tensors")
        }
        if m.VisionModel.PositionEmbedding == nil {
            return errors.New("lfm2: missing vision position embedding tensors")
        }
        if m.VisionModel.PostLayerNorm == nil {
            return errors.New("lfm2: missing vision post layer norm tensors")
        }
        if len(m.VisionModel.Layers) == 0 {
            return errors.New("lfm2: missing vision encoder layers")
        }
        for i, layer := range m.VisionModel.Layers {
            if layer.LayerNorm1 == nil || layer.LayerNorm2 == nil || layer.SelfAttention == nil || layer.MLP == nil {
                return fmt.Errorf("lfm2: missing vision layer tensors at v.blk.%d", i)
            }
            if layer.SelfAttention.Query == nil || layer.SelfAttention.Key == nil || layer.SelfAttention.Value == nil || layer.SelfAttention.Output == nil {
                return fmt.Errorf("lfm2: missing vision attention tensors at v.blk.%d", i)
            }
            if layer.MLP.Up == nil || layer.MLP.Down == nil {
                return fmt.Errorf("lfm2: missing vision feed-forward tensors at v.blk.%d", i)
            }
        }

        if m.VisionProjector == nil || m.VisionProjector.Linear1 == nil || m.VisionProjector.Linear2 == nil {
            return errors.New("lfm2: missing multimodal projector tensors")
        }
    }

    return nil
}

func New(c fs.Config) (model.Model, error) {
    if c.String("tokenizer.ggml.model") != "gpt2" {
        return nil, model.ErrUnsupportedTokenizer
    }

    numExperts := int(c.Uint("expert_count"))
    isMoE := numExperts > 0
    numExpertsUsed := int(c.Uint("expert_used_count"))
    if isMoE {
        if numExperts <= 0 {
            return nil, fmt.Errorf("lfm2: invalid expert_count=%d", numExperts)
        }
        if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
            return nil, fmt.Errorf("lfm2: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
        }
    }

    vocabulary := tokenizer.Vocabulary{
        Values: c.Strings("tokenizer.ggml.tokens"),
        Scores: c.Floats("tokenizer.ggml.scores"),
@@ -105,8 +241,16 @@ func New(c fs.Config) (model.Model, error) {
    }

    m := Model{
        Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
        Layers:    make([]Layer, c.Uint("block_count")),
        Tokenizer:       tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
        Layers:          make([]Layer, c.Uint("block_count")),
        ImageProcessor:  newImageProcessor(c),
        VisionModel:     newVisionModel(c),
        VisionProjector: &VisionProjector{},
        imageRowColIDs:  make(map[imageGridPos]int32),
        projectorOptions: VisionProjectorOptions{
            scaleFactor:  int(c.Uint("vision.projector.scale_factor", 2)),
            useLayerNorm: c.Bool("vision.projector.use_layernorm", false),
        },
        Options: Options{
            hiddenSize: int(c.Uint("embedding_length")),
            headDim:    int(c.Uint("attention.key_length")),
@@ -116,9 +260,66 @@ func New(c fs.Config) (model.Model, error) {
            ropeBase:              c.Float("rope.freq_base"),
            ropeScale:             c.Float("rope.scaling.factor", 1),
            originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
            numExperts:            numExperts,
            numExpertsUsed:        numExpertsUsed,
            normTopKProb:          c.Bool("norm_top_k_prob", true),
            expertWeightsScale:    c.Float("expert_weights_scale", 1.0),
            expertGatingFunc:      c.Uint("expert_gating_func", expertGatingFuncSoftmax),
        },
    }

    lookupTokenID := func(token string) int32 {
        for i, t := range vocabulary.Values {
            if t == token {
                return int32(i)
            }
        }
        return 0
    }

    resolveTokenID := func(explicitKey, token string, fallback uint32) int32 {
        if explicitKey != "" {
            if id := c.Uint(explicitKey); id != 0 {
                return int32(id)
            }
        }
        if tokenID := lookupTokenID(token); tokenID != 0 {
            return tokenID
        }
        return int32(fallback)
    }

    m.imageTokenID = resolveTokenID("vision.image_token_id", "<image>", 396)
    m.imageStartToken = resolveTokenID("vision.image_start_token_id", "<|image_start|>", 0)
    m.imageEndToken = resolveTokenID("vision.image_end_token_id", "<|image_end|>", 0)
    m.imageThumbnailID = resolveTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>", 0)
    m.useSpecialTokens = c.Bool("vision.use_image_special_tokens", true)

    maxGridTokens := int(c.Uint("vision.max_tiles", 10))
    if maxGridTokens <= 0 {
        maxGridTokens = 10
    }
    for row := 1; row <= maxGridTokens; row++ {
        for col := 1; col <= maxGridTokens; col++ {
            token := fmt.Sprintf("<|img_row_%d_col_%d|>", row, col)
            if tokenID := lookupTokenID(token); tokenID > 0 {
                m.imageRowColIDs[imageGridPos{row: row, col: col}] = tokenID
            }
        }
    }

    if !m.useSpecialTokens {
        m.imageStartToken = 0
        m.imageEndToken = 0
        m.imageThumbnailID = 0
        m.imageRowColIDs = map[imageGridPos]int32{}
    }

    if c.Uint("vision.block_count") == 0 {
        m.VisionModel = nil
        m.VisionProjector = nil
    }
    type headCounts interface {
        HeadCount() []uint64
        HeadCountKV() []uint64
@@ -133,6 +334,14 @@ func New(c fs.Config) (model.Model, error) {

    m.numHeadsByLayer = make([]int, len(m.Layers))
    m.numKVHeadsByLayer = make([]int, len(m.Layers))
    leadingDenseBlockCount := int(c.Uint("leading_dense_block_count"))
    if leadingDenseBlockCount < 0 {
        leadingDenseBlockCount = 0
    }
    if leadingDenseBlockCount > len(m.Layers) {
        leadingDenseBlockCount = len(m.Layers)
    }

    for i := range m.Layers {
        m.numHeadsByLayer[i] = int(headCount[i])
        m.numKVHeadsByLayer[i] = int(headCountKV[i])
@@ -142,6 +351,12 @@ func New(c fs.Config) (model.Model, error) {
        } else {
            m.Layers[i].Operator = &Attention{}
        }

        if isMoE && i >= leadingDenseBlockCount {
            m.Layers[i].MLP = &sparseMLP{}
        } else {
            m.Layers[i].MLP = &denseMLP{}
        }
    }

    lCache := int(c.Uint("shortconv.l_cache"))
@@ -188,22 +403,77 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
    return sa.Output.Forward(ctx, attention)
}

type MLP struct {
type FeedForward interface {
    Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

type denseMLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
func (mlp *denseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type sparseMLP struct {
    Router *nn.Linear      `gguf:"ffn_gate_inp"`
    Gate   *nn.LinearBatch `gguf:"ffn_gate_exps"`
    Up     *nn.LinearBatch `gguf:"ffn_up_exps"`
    Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
    Bias   ml.Tensor       `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}

func (mlp *sparseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
    // hiddenState: [hidden, tokens]
    routerLogits := mlp.Router.Forward(ctx, hiddenState)

    probs := routerLogits.Softmax(ctx)
    if opts.expertGatingFunc == expertGatingFuncSigmoid {
        probs = routerLogits.Sigmoid(ctx)
    }

    selectionProbs := probs
    if mlp.Bias != nil {
        selectionProbs = selectionProbs.Add(ctx, mlp.Bias)
    }

    selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
    routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
    if opts.normTopKProb {
        routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
        weightsSum := routingWeights.SumRows(ctx)
        weightsSum = weightsSum.Clamp(ctx, 1e-6, float32(math.Inf(1)))
        routingWeights = routingWeights.Div(ctx, weightsSum)
        routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
    }
    if opts.expertWeightsScale != 1 {
        routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
    }

    // Build routing-weights branch early to enable topk-MoE fusion.
    ctx.Forward(routingWeights)

    hiddenState3D := hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
    experts := mlp.Gate.Forward(ctx, hiddenState3D, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenState3D, selectedExperts))
    experts = mlp.Down.Forward(ctx, experts, selectedExperts)
    experts = experts.Mul(ctx, routingWeights)

    nextState := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
    for i := 1; i < opts.numExpertsUsed; i++ {
        nextState = nextState.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
    }

    return nextState
}
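The sparseMLP router above computes per-token expert probabilities, keeps the top k, and (in the normTopKProb case) renormalizes the kept weights. The same math on one token's logits as a hedged sketch in plain Go (hypothetical helper; softmax gating, no bias term; assumes "math" and "slices" imports and Go 1.22+ for integer range loops):

// topKRoutingWeights: softmax the router logits, pick the k largest
// probabilities, then renormalize the kept weights to sum to 1.
func topKRoutingWeights(logits []float32, k int) (idx []int, w []float32) {
    maxv := logits[0]
    for _, v := range logits {
        if v > maxv {
            maxv = v
        }
    }
    probs := make([]float32, len(logits))
    var sum float32
    for i, v := range logits {
        probs[i] = float32(math.Exp(float64(v - maxv)))
        sum += probs[i]
    }
    for i := range probs {
        probs[i] /= sum
    }
    for range k {
        best := -1
        for i, p := range probs {
            if slices.Contains(idx, i) {
                continue // expert already selected
            }
            if best < 0 || p > probs[best] {
                best = i
            }
        }
        idx = append(idx, best)
        w = append(w, probs[best])
    }
    var kept float32
    for _, v := range w {
        kept += v
    }
    for i := range w {
        w[i] /= kept
    }
    return idx, w
}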
type Layer struct {
    AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
    Operator      Operator
    MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
    MLP           *MLP
    MLP           FeedForward
}

func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
@@ -229,10 +499,233 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
    return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func multimodalTokenCount(mm input.Multimodal) int {
    if mm.Tensor != nil {
        return mm.Tensor.Dim(1)
    }

    switch data := mm.Data.(type) {
    case int:
        return data
    case int32:
        return int(data)
    case visionChunkData:
        return data.tokens
    case *visionChunkData:
        if data != nil {
            return data.tokens
        }
    }

    return 0
}

func multimodalChunkInfo(mm input.Multimodal) visionChunkData {
    switch data := mm.Data.(type) {
    case visionChunkData:
        return data
    case *visionChunkData:
        if data != nil {
            return *data
        }
    }

    return visionChunkData{
        tokens: multimodalTokenCount(mm),
    }
}

func multimodalLayout(mm []input.Multimodal) visionEmbeddingLayout {
    layout := visionEmbeddingLayout{rows: 1, cols: 1}
    if len(mm) == 0 {
        return layout
    }

    first := multimodalChunkInfo(mm[0])
    if first.layout != nil {
        return *first.layout
    }

    return layout
}

func (m *Model) imageRowColToken(row, col int) int32 {
    if row <= 0 || col <= 0 {
        return 0
    }
    return m.imageRowColIDs[imageGridPos{row: row, col: col}]
}

func (m *Model) appendImageChunk(result []*input.Input, chunk input.Multimodal, imageToken int32, hash uint64) ([]*input.Input, error) {
    tokenCount := multimodalTokenCount(chunk)
    if tokenCount <= 0 {
        return nil, errors.New("lfm2: multimodal input has no tokens")
    }

    result = append(result, &input.Input{
        Token:          imageToken,
        Multimodal:     []input.Multimodal{chunk},
        MultimodalHash: hash,
        SameBatch:      tokenCount - 1,
    })

    for range tokenCount - 1 {
        result = append(result, &input.Input{Token: imageToken})
    }

    return result, nil
}
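appendImageChunk above emits one carrier token holding the multimodal payload, followed by tokenCount-1 plain filler tokens, with SameBatch forcing them all into the same batch. Worked example: a single 64-token image chunk with start and end markers and one text token on each side expands to 1 + 1 + 64 + 1 + 1 = 68 inputs, with SameBatch = 63 on the carrier, which is exactly what TestPostTokenizeWithSpecialImageTokens below asserts.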
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
    if m.VisionModel == nil || m.VisionProjector == nil || len(m.VisionModel.Layers) == 0 {
        return nil, model.ErrNoVisionModel
    }

    img, _, err := image.Decode(bytes.NewReader(multimodalData))
    if err != nil {
        return nil, err
    }

    processedImages, layout, err := m.ImageProcessor.ProcessImage(img)
    if err != nil {
        return nil, err
    }

    if m.ImageProcessor.patchSize <= 0 {
        return nil, errors.New("lfm2: invalid vision patch size")
    }

    layoutInfo := &visionEmbeddingLayout{
        rows:         layout.rows,
        cols:         layout.cols,
        hasThumbnail: layout.hasThumbnail,
    }

    mm := make([]input.Multimodal, 0, len(processedImages))
    for i, processed := range processedImages {
        patches := visionPatchGrid{
            Width:  processed.size.X / m.ImageProcessor.patchSize,
            Height: processed.size.Y / m.ImageProcessor.patchSize,
        }
        if patches.Width == 0 || patches.Height == 0 {
            return nil, errors.New("lfm2: invalid resized image dimensions")
        }

        pixelValues := ctx.Input().FromFloats(processed.data, processed.size.X, processed.size.Y, m.ImageProcessor.numChannels)
        visionOutputs := m.VisionModel.Forward(ctx, pixelValues, patches)
        projected := m.VisionProjector.Forward(ctx, visionOutputs, patches, m.projectorOptions)

        chunk := visionChunkData{
            tokens:    projected.Dim(1),
            row:       processed.row,
            col:       processed.col,
            thumbnail: processed.thumbnail,
        }
        if i == 0 {
            chunk.layout = layoutInfo
        }

        mm = append(mm, input.Multimodal{
            Tensor: projected,
            Data:   chunk,
        })
    }

    return mm, nil
}

func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
    var result []*input.Input

    imageToken := m.imageTokenID
    if imageToken == 0 {
        imageToken = 396
    }
    useSpecialTokens := m.useSpecialTokens || m.imageStartToken > 0 || m.imageEndToken > 0 || m.imageThumbnailID > 0 || len(m.imageRowColIDs) > 0

    for _, inp := range inputs {
        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
            continue
        }

        layout := multimodalLayout(inp.Multimodal)
        if layout.rows <= 0 {
            layout.rows = 1
        }
        if layout.cols <= 0 {
            layout.cols = 1
        }
        tiles := layout.rows * layout.cols
        multitile := tiles > 1

        if useSpecialTokens && m.imageStartToken > 0 {
            result = append(result, &input.Input{Token: m.imageStartToken})
        }

        for i, mm := range inp.Multimodal {
            chunk := multimodalChunkInfo(mm)
            if chunk.tokens <= 0 {
                chunk.tokens = multimodalTokenCount(mm)
            }

            if multitile && !chunk.thumbnail && chunk.row == 0 && chunk.col == 0 && i < tiles {
                chunk.row = i/layout.cols + 1
                chunk.col = i%layout.cols + 1
            }
            if multitile && layout.hasThumbnail && i == tiles {
                chunk.thumbnail = true
            }

            if useSpecialTokens && multitile {
                if chunk.thumbnail {
                    if m.imageThumbnailID > 0 {
                        result = append(result, &input.Input{Token: m.imageThumbnailID})
                    }
                } else if marker := m.imageRowColToken(chunk.row, chunk.col); marker > 0 {
                    result = append(result, &input.Input{Token: marker})
                }
            }

            var err error
            result, err = m.appendImageChunk(result, input.Multimodal{
                Tensor: mm.Tensor,
                Data:   chunk,
            }, imageToken, inp.MultimodalHash)
            if err != nil {
                return nil, err
            }
        }

        if useSpecialTokens && m.imageEndToken > 0 {
            result = append(result, &input.Input{Token: m.imageEndToken})
        }
    }

    return result, nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

    hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
    if len(batch.Multimodal) > 0 {
        // We splice vision embeddings into token embeddings in-place; duplicate to
        // avoid aliasing the raw embedding output graph.
        hiddenState = hiddenState.Duplicate(ctx)
    }
    for _, mm := range batch.Multimodal {
        offset := mm.Index
        for _, multimodal := range mm.Multimodal {
            if multimodal.Tensor == nil {
                continue
            }

            visionOutputs := multimodal.Tensor
            ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, offset*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
            offset += visionOutputs.Dim(1)
        }
    }

    for i, layer := range m.Layers {
        m.Cache.SetLayer(i)
@@ -251,4 +744,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

func init() {
    model.Register("lfm2", New)
    model.Register("lfm2moe", New)
}
160  model/models/lfm2/model_multimodal_test.go  Normal file
@@ -0,0 +1,160 @@
package lfm2

import (
    "testing"

    "github.com/ollama/ollama/model/input"
)

func TestPostTokenizeWithSpecialImageTokens(t *testing.T) {
    m := &Model{
        imageTokenID:     396,
        imageStartToken:  2,
        imageEndToken:    3,
        useSpecialTokens: true,
    }

    in := []*input.Input{
        {Token: 11},
        {Multimodal: []input.Multimodal{{Data: 64}}, MultimodalHash: 123},
        {Token: 12},
    }

    out, err := m.PostTokenize(in)
    if err != nil {
        t.Fatalf("PostTokenize returned error: %v", err)
    }

    if len(out) != 68 {
        t.Fatalf("expected 68 tokens, got %d", len(out))
    }

    if out[0].Token != 11 {
        t.Fatalf("out[0].Token = %d, want 11", out[0].Token)
    }
    if out[1].Token != 2 {
        t.Fatalf("out[1].Token = %d, want 2", out[1].Token)
    }

    firstImage := out[2]
    if firstImage.Token != 396 {
        t.Fatalf("out[2].Token = %d, want 396", firstImage.Token)
    }
    if len(firstImage.Multimodal) != 1 {
        t.Fatalf("expected multimodal payload on first image token")
    }
    if firstImage.MultimodalHash != 123 {
        t.Fatalf("out[2].MultimodalHash = %d, want 123", firstImage.MultimodalHash)
    }
    if firstImage.SameBatch != 63 {
        t.Fatalf("out[2].SameBatch = %d, want 63", firstImage.SameBatch)
    }

    for i := 3; i < 66; i++ {
        if out[i].Token != 396 {
            t.Fatalf("out[%d].Token = %d, want 396", i, out[i].Token)
        }
        if len(out[i].Multimodal) != 0 {
            t.Fatalf("out[%d] should not carry multimodal payload", i)
        }
    }

    if out[66].Token != 3 {
        t.Fatalf("out[66].Token = %d, want 3", out[66].Token)
    }
    if out[67].Token != 12 {
        t.Fatalf("out[67].Token = %d, want 12", out[67].Token)
    }
}

func TestPostTokenizeWithoutSpecialImageTokens(t *testing.T) {
    m := &Model{
        imageTokenID:     777,
        useSpecialTokens: false,
    }

    in := []*input.Input{
        {Multimodal: []input.Multimodal{{Data: 5}}, MultimodalHash: 9},
    }

    out, err := m.PostTokenize(in)
    if err != nil {
        t.Fatalf("PostTokenize returned error: %v", err)
    }

    if len(out) != 5 {
        t.Fatalf("expected 5 tokens, got %d", len(out))
    }
    if out[0].Token != 777 || out[0].SameBatch != 4 || len(out[0].Multimodal) != 1 {
        t.Fatalf("unexpected first token: %+v", *out[0])
    }
    for i := 1; i < 5; i++ {
        if out[i].Token != 777 {
            t.Fatalf("out[%d].Token = %d, want 777", i, out[i].Token)
        }
        if len(out[i].Multimodal) != 0 {
            t.Fatalf("out[%d] should not carry multimodal payload", i)
        }
    }
}

func TestPostTokenizeMultiTileLayoutTokens(t *testing.T) {
    m := &Model{
        imageTokenID:     396,
        imageStartToken:  498,
        imageEndToken:    499,
        imageThumbnailID: 497,
        imageRowColIDs: map[imageGridPos]int32{
            {row: 1, col: 1}: 397,
            {row: 1, col: 2}: 398,
        },
        useSpecialTokens: true,
    }

    layout := &visionEmbeddingLayout{rows: 1, cols: 2, hasThumbnail: true}
    in := []*input.Input{{
        Multimodal: []input.Multimodal{
            {Data: visionChunkData{tokens: 3, row: 1, col: 1, layout: layout}},
            {Data: visionChunkData{tokens: 3, row: 1, col: 2}},
            {Data: visionChunkData{tokens: 2, thumbnail: true}},
        },
        MultimodalHash: 1,
    }}

    out, err := m.PostTokenize(in)
    if err != nil {
        t.Fatalf("PostTokenize returned error: %v", err)
    }

    got := make([]int32, len(out))
    for i := range out {
        got[i] = out[i].Token
    }

    want := []int32{
        498, // <|image_start|>
        397, // <|img_row_1_col_1|>
        396, 396, 396,
        398, // <|img_row_1_col_2|>
        396, 396, 396,
        497, // <|img_thumbnail|>
        396, 396,
        499, // <|image_end|>
    }

    if len(got) != len(want) {
        t.Fatalf("len(out) = %d, want %d", len(got), len(want))
    }
    for i := range want {
        if got[i] != want[i] {
            t.Fatalf("out[%d].Token = %d, want %d", i, got[i], want[i])
        }
    }

    if len(out[2].Multimodal) != 1 || len(out[6].Multimodal) != 1 || len(out[10].Multimodal) != 1 {
        t.Fatalf("expected multimodal payload on first token of each chunk")
    }
    if out[2].SameBatch != 2 || out[6].SameBatch != 2 || out[10].SameBatch != 1 {
        t.Fatalf("unexpected SameBatch values: [%d %d %d]", out[2].SameBatch, out[6].SameBatch, out[10].SameBatch)
    }
}
184  model/models/lfm2/model_vision.go  Normal file
@@ -0,0 +1,184 @@
package lfm2

import (
    "math"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
)

const lfm2VisionBatchSize = 1

type visionPatchGrid struct {
    Width  int
    Height int
}

type VisionSelfAttention struct {
    Query  *nn.Linear `gguf:"attn_q"`
    Key    *nn.Linear `gguf:"attn_k"`
    Value  *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}

func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    headDim := opts.hiddenSize / opts.numHeads

    query := sa.Query.Forward(ctx, hiddenState)
    key := sa.Key.Forward(ctx, hiddenState)
    value := sa.Value.Forward(ctx, hiddenState)

    query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), lfm2VisionBatchSize)
    key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), lfm2VisionBatchSize)
    value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), lfm2VisionBatchSize)

    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
    attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), lfm2VisionBatchSize)
    return sa.Output.Forward(ctx, attention)
}

type VisionMLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
    return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
}

type VisionEncoderLayer struct {
    LayerNorm1    *nn.LayerNorm `gguf:"ln1"`
    SelfAttention *VisionSelfAttention

    LayerNorm2 *nn.LayerNorm `gguf:"ln2"`
    MLP        *VisionMLP
}

func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, hiddenState, opts)
    hiddenState = hiddenState.Add(ctx, residual)

    residual = hiddenState
    hiddenState = l.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState)
    return hiddenState.Add(ctx, residual)
}

type VisionModelOptions struct {
    hiddenSize, numHeads int
    imageSize, patchSize int
    eps                  float32
}

type VisionModel struct {
    PatchEmbedding    *nn.Conv2D    `gguf:"patch_embd"`
    PositionEmbedding *nn.Embedding `gguf:"position_embd"`
    PostLayerNorm     *nn.LayerNorm `gguf:"post_ln"`

    Layers []VisionEncoderLayer `gguf:"blk"`

    *VisionModelOptions
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
    numPatches := patches.Width * patches.Height

    hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
    hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
    hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

    if m.PositionEmbedding != nil {
        posTokens := m.PositionEmbedding.Weight.Dim(1)
        source := int(math.Sqrt(float64(posTokens)))

        var positionEmbeddings ml.Tensor
        if source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height) {
            // SigLIP2 NAFlex-style position interpolation for variable image sizes.
            positionIDs := ctx.Arange(0, float32(posTokens), 1, ml.DTypeI32)
            positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
            positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
            positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
            positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
                patches.Width,
                patches.Height,
                hiddenState.Dim(0),
                1,
            }, ml.SamplingModeBilinear)
            positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
            positionEmbeddings = positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
        } else {
            positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
            positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
        }

        hiddenState = hiddenState.Add(ctx, positionEmbeddings)
    }

    for _, layer := range m.Layers {
        hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
    }

    return m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
}
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length", 1152)),
|
||||
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||
imageSize: int(c.Uint("vision.image_size", 256)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type VisionProjector struct {
|
||||
LayerNorm *nn.LayerNorm `gguf:"layer_norm"`
|
||||
Linear1 *nn.Linear `gguf:"1"`
|
||||
Linear2 *nn.Linear `gguf:"2"`
|
||||
}
|
||||
|
||||
type VisionProjectorOptions struct {
|
||||
scaleFactor int
|
||||
useLayerNorm bool
|
||||
}
|
||||
|
||||
func (p *VisionProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, opts VisionProjectorOptions) ml.Tensor {
|
||||
hiddenSize := visionOutputs.Dim(0)
|
||||
featureMap := visionOutputs
|
||||
|
||||
merge := max(opts.scaleFactor, 1)
|
||||
if merge > 1 {
|
||||
width := patches.Width
|
||||
height := patches.Height
|
||||
|
||||
featureMap = featureMap.Reshape(ctx, hiddenSize, width, height)
|
||||
|
||||
// Match llama.cpp patch merger: pad spatial dims to merge factor.
|
||||
padWidth := (merge - width%merge) % merge
|
||||
padHeight := (merge - height%merge) % merge
|
||||
if padWidth != 0 || padHeight != 0 {
|
||||
featureMap = featureMap.Pad(ctx, 0, padWidth, padHeight, 0)
|
||||
width += padWidth
|
||||
height += padHeight
|
||||
}
|
||||
|
||||
featureMap = featureMap.Reshape(ctx, hiddenSize*merge, width/merge, height)
|
||||
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx, hiddenSize*merge*merge, height/merge, width/merge)
|
||||
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx)
|
||||
featureMap = featureMap.Contiguous(ctx, featureMap.Dim(0), featureMap.Dim(1)*featureMap.Dim(2))
|
||||
}
|
||||
|
||||
if opts.useLayerNorm && p.LayerNorm != nil {
|
||||
featureMap = p.LayerNorm.Forward(ctx, featureMap, 1e-5)
|
||||
}
|
||||
|
||||
featureMap = p.Linear1.Forward(ctx, featureMap).GELU(ctx)
|
||||
return p.Linear2.Forward(ctx, featureMap)
|
||||
}
|
||||
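// Illustrative patch-grid arithmetic for the vision tower above (a sketch,
// not part of the diff; the helper name is hypothetical): with the default
// imageSize=256 and patchSize=16 the encoder sees a 16x16 grid, so numPatches
// matches the 256 learned position embeddings and the NAFlex interpolation
// branch is skipped; a 512x256 resize instead yields a 32x16 grid and takes
// the bilinear interpolation path.
func examplePatchGrid(width, height, patchSize int) visionPatchGrid {
	return visionPatchGrid{Width: width / patchSize, Height: height / patchSize}
}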
260 model/models/lfm2/process_image.go Normal file
@@ -0,0 +1,260 @@
package lfm2

import (
	"image"
	stdimage "image/draw"
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/model/imageproc"
)

type ImageProcessor struct {
	imageSize, patchSize, numChannels int
	downsampleFactor                  int
	imageMean, imageStd               [3]float32

	doImageSplitting bool
	minTiles         int
	maxTiles         int
	useThumbnail     bool
	tileSize         int

	minImageTokens     int
	maxImageTokens     int
	maxPixelsTolerance float64
}

type processedVisionImage struct {
	data      []float32
	size      image.Point
	row       int
	col       int
	thumbnail bool
}

type processedVisionLayout struct {
	rows         int
	cols         int
	hasThumbnail bool
}

func newImageProcessor(c fs.Config) ImageProcessor {
	mean := c.Floats("vision.image_mean")
	std := c.Floats("vision.image_std")

	processor := ImageProcessor{
		imageSize:          int(c.Uint("vision.image_size", 256)),
		patchSize:          int(c.Uint("vision.patch_size", 16)),
		numChannels:        int(c.Uint("vision.num_channels", 3)),
		downsampleFactor:   int(c.Uint("vision.projector.scale_factor", 2)),
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   c.Bool("vision.do_image_splitting", true),
		minTiles:           int(c.Uint("vision.min_tiles", 2)),
		maxTiles:           int(c.Uint("vision.max_tiles", 10)),
		useThumbnail:       c.Bool("vision.use_thumbnail", true),
		tileSize:           int(c.Uint("vision.tile_size", 512)),
		minImageTokens:     int(c.Uint("vision.min_image_tokens", 64)),
		maxImageTokens:     int(c.Uint("vision.max_image_tokens", 256)),
		maxPixelsTolerance: float64(c.Float("vision.max_pixels_tolerance", 2.0)),
	}

	if len(mean) >= 3 {
		processor.imageMean = [3]float32{mean[0], mean[1], mean[2]}
	}
	if len(std) >= 3 {
		processor.imageStd = [3]float32{std[0], std[1], std[2]}
	}

	// Keep defaults aligned with HF unless explicitly configured.
	if processor.downsampleFactor <= 0 {
		processor.downsampleFactor = 2
	}
	if processor.patchSize <= 0 {
		processor.patchSize = 16
	}
	if processor.tileSize <= 0 {
		processor.tileSize = 512
	}
	if processor.minTiles <= 0 {
		processor.minTiles = 2
	}
	if processor.maxTiles < processor.minTiles {
		processor.maxTiles = processor.minTiles
	}
	if processor.minImageTokens <= 0 {
		processor.minImageTokens = 64
	}
	if processor.maxImageTokens < processor.minImageTokens {
		processor.maxImageTokens = processor.minImageTokens
	}
	if processor.maxPixelsTolerance <= 0 {
		processor.maxPixelsTolerance = 2.0
	}

	return processor
}

func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionImage, processedVisionLayout, error) {
	img = imageproc.Composite(img)

	orig := img.Bounds().Size()
	resizedWidth, resizedHeight := p.smartResize(orig.Y, orig.X)

	layout := processedVisionLayout{rows: 1, cols: 1}
	if p.shouldSplit(orig.Y, orig.X) {
		gridWidth, gridHeight, targetWidth, targetHeight := p.gridLayout(orig.Y, orig.X)
		layout.rows = gridHeight
		layout.cols = gridWidth
		layout.hasThumbnail = p.useThumbnail && gridWidth*gridHeight != 1

		resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
		images := make([]processedVisionImage, 0, gridWidth*gridHeight+1)
		for row := range gridHeight {
			for col := range gridWidth {
				rect := image.Rect(
					col*p.tileSize,
					row*p.tileSize,
					(col+1)*p.tileSize,
					(row+1)*p.tileSize,
				)
				tile := cropImage(resized, rect)
				images = append(images, processedVisionImage{
					data: imageproc.Normalize(tile, p.imageMean, p.imageStd, true, true),
					size: tile.Bounds().Size(),
					row:  row + 1,
					col:  col + 1,
				})
			}
		}

		if layout.hasThumbnail {
			thumbnail := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
			images = append(images, processedVisionImage{
				data:      imageproc.Normalize(thumbnail, p.imageMean, p.imageStd, true, true),
				size:      thumbnail.Bounds().Size(),
				thumbnail: true,
			})
		}

		return images, layout, nil
	}

	single := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
	return []processedVisionImage{{
		data: imageproc.Normalize(single, p.imageMean, p.imageStd, true, true),
		size: single.Bounds().Size(),
	}}, layout, nil
}

func (p ImageProcessor) shouldSplit(height, width int) bool {
	if !p.doImageSplitting || p.minTiles == 1 && p.maxTiles == 1 {
		return false
	}

	totalFactor := p.patchSize * p.downsampleFactor
	hBar := max(p.patchSize, roundByFactor(height, totalFactor))
	wBar := max(p.patchSize, roundByFactor(width, totalFactor))

	limit := float64(p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor)
	limit *= p.maxPixelsTolerance

	return float64(hBar*wBar) > limit
}

func (p ImageProcessor) smartResize(height, width int) (int, int) {
	totalFactor := p.patchSize * p.downsampleFactor
	minPixels := p.minImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
	maxPixels := p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor

	hBar := max(totalFactor, roundByFactor(height, totalFactor))
	wBar := max(totalFactor, roundByFactor(width, totalFactor))

	if hBar*wBar > maxPixels {
		beta := math.Sqrt(float64(height*width) / float64(maxPixels))
		hBar = max(totalFactor, int(math.Floor(float64(height)/beta/float64(totalFactor)))*totalFactor)
		wBar = max(totalFactor, int(math.Floor(float64(width)/beta/float64(totalFactor)))*totalFactor)
	} else if hBar*wBar < minPixels {
		beta := math.Sqrt(float64(minPixels) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(totalFactor))) * totalFactor
		wBar = int(math.Ceil(float64(width)*beta/float64(totalFactor))) * totalFactor
	}

	return wBar, hBar
}

func (p ImageProcessor) gridLayout(height, width int) (gridWidth, gridHeight, targetWidth, targetHeight int) {
	aspectRatio := float64(width) / float64(height)
	targetRatios := p.targetRatios()
	bestRatio := clipImageSize{width: 1, height: 1}
	bestRatioDiff := math.MaxFloat64
	area := float64(width * height)

	for _, ratio := range targetRatios {
		targetAspect := float64(ratio.width) / float64(ratio.height)
		ratioDiff := math.Abs(aspectRatio - targetAspect)

		if ratioDiff < bestRatioDiff {
			bestRatioDiff = ratioDiff
			bestRatio = ratio
			continue
		}

		if ratioDiff == bestRatioDiff {
			targetArea := float64(p.tileSize * p.tileSize * ratio.width * ratio.height)
			if area > 0.5*targetArea {
				bestRatio = ratio
			}
		}
	}

	return bestRatio.width, bestRatio.height, p.tileSize * bestRatio.width, p.tileSize * bestRatio.height
}

type clipImageSize struct {
	width  int
	height int
}

func (p ImageProcessor) targetRatios() []clipImageSize {
	targetRatios := make([]clipImageSize, 0, p.maxTiles*p.maxTiles)
	for n := p.minTiles; n <= p.maxTiles; n++ {
		for w := 1; w <= n; w++ {
			for h := 1; h <= n; h++ {
				if w*h < p.minTiles || w*h > p.maxTiles {
					continue
				}
				targetRatios = append(targetRatios, clipImageSize{width: w, height: h})
			}
		}
	}

	unique := targetRatios[:0]
	for _, ratio := range targetRatios {
		if slices.Contains(unique, ratio) {
			continue
		}
		unique = append(unique, ratio)
	}

	slices.SortFunc(unique, func(a, b clipImageSize) int {
		return a.width*a.height - b.width*b.height
	})

	return unique
}

func roundByFactor(number, factor int) int {
	if factor <= 0 {
		return number
	}
	return int(math.RoundToEven(float64(number)/float64(factor))) * factor
}

func cropImage(img image.Image, rect image.Rectangle) image.Image {
	dst := image.NewRGBA(image.Rect(0, 0, rect.Dx(), rect.Dy()))
	stdimage.Draw(dst, dst.Bounds(), img, rect.Min, stdimage.Src)
	return dst
}
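// Worked example of the smartResize pixel budget above (illustrative only,
// using the default configuration; the helper name is hypothetical). With
// patchSize=16 and downsampleFactor=2, totalFactor is 32, so the token budget
// of [64, 256] image tokens corresponds to [65536, 262144] pixels. A
// 3000x1000 input (3,000,000 pixels) exceeds maxPixels, so beta =
// sqrt(3000000/262144) ≈ 3.38 and both sides are floored to multiples of 32,
// giving 864x288 = 248832 pixels, back under the budget and still aligned to
// patchSize*downsampleFactor.
func exampleSmartResizeBudget() (int, int) {
	p := ImageProcessor{
		patchSize:        16,
		downsampleFactor: 2,
		minImageTokens:   64,
		maxImageTokens:   256,
	}
	return p.smartResize(1000, 3000) // expect (864, 288) as (width, height)
}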
105 model/models/lfm2/process_image_test.go Normal file
@@ -0,0 +1,105 @@
package lfm2

import (
	"image"
	"image/color"
	"testing"
)

func TestProcessImageSingleTile(t *testing.T) {
	p := ImageProcessor{
		patchSize:          16,
		downsampleFactor:   2,
		numChannels:        3,
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   true,
		minTiles:           2,
		maxTiles:           10,
		useThumbnail:       true,
		tileSize:           512,
		minImageTokens:     64,
		maxImageTokens:     256,
		maxPixelsTolerance: 2.0,
	}

	img := image.NewRGBA(image.Rect(0, 0, 320, 320))
	out, layout, err := p.ProcessImage(img)
	if err != nil {
		t.Fatalf("ProcessImage returned error: %v", err)
	}

	if layout.rows != 1 || layout.cols != 1 || layout.hasThumbnail {
		t.Fatalf("layout = %+v, want rows=1 cols=1 hasThumbnail=false", layout)
	}
	if len(out) != 1 {
		t.Fatalf("len(out) = %d, want 1", len(out))
	}
	if out[0].size != (image.Point{X: 320, Y: 320}) {
		t.Fatalf("single image size = %+v, want 320x320", out[0].size)
	}
	if out[0].thumbnail {
		t.Fatalf("single image should not be marked as thumbnail")
	}
}

func TestProcessImageDynamicTiling(t *testing.T) {
	p := ImageProcessor{
		patchSize:          16,
		downsampleFactor:   2,
		numChannels:        3,
		imageMean:          [3]float32{0.5, 0.5, 0.5},
		imageStd:           [3]float32{0.5, 0.5, 0.5},
		doImageSplitting:   true,
		minTiles:           2,
		maxTiles:           10,
		useThumbnail:       true,
		tileSize:           512,
		minImageTokens:     64,
		maxImageTokens:     256,
		maxPixelsTolerance: 2.0,
	}

	// Wide image that should trigger multi-tile splitting.
	img := image.NewRGBA(image.Rect(0, 0, 3000, 1000))
	fill := color.RGBA{R: 120, G: 90, B: 60, A: 255}
	for y := range 1000 {
		for x := range 3000 {
			img.Set(x, y, fill)
		}
	}

	out, layout, err := p.ProcessImage(img)
	if err != nil {
		t.Fatalf("ProcessImage returned error: %v", err)
	}

	if layout.rows*layout.cols <= 1 {
		t.Fatalf("expected multi-tile layout, got %+v", layout)
	}
	if !layout.hasThumbnail {
		t.Fatalf("expected thumbnail for multi-tile layout")
	}

	wantLen := layout.rows*layout.cols + 1
	if len(out) != wantLen {
		t.Fatalf("len(out) = %d, want %d", len(out), wantLen)
	}

	for i := range layout.rows * layout.cols {
		if out[i].size != (image.Point{X: 512, Y: 512}) {
			t.Fatalf("tile[%d] size = %+v, want 512x512", i, out[i].size)
		}
		if out[i].thumbnail {
			t.Fatalf("tile[%d] should not be marked as thumbnail", i)
		}
	}

	thumb := out[len(out)-1]
	if !thumb.thumbnail {
		t.Fatalf("last chunk should be thumbnail")
	}
	if thumb.size.X%32 != 0 || thumb.size.Y%32 != 0 {
		t.Fatalf("thumbnail size = %+v, want dimensions aligned to 32", thumb.size)
	}
}
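// By the gridLayout arithmetic above, the 3000x1000 test image should select
// a 3x1 grid: its aspect ratio of 3.0 is matched exactly by the (3,1) ratio,
// and 3 tiles fits the 2..10 tile budget, giving three 512x512 tiles plus one
// thumbnail. A sketch of that expected layout follows (illustrative only; the
// test itself asserts just the multi-tile invariants, and the helper name is
// hypothetical).
func exampleExpectedDynamicLayout() processedVisionLayout {
	return processedVisionLayout{rows: 1, cols: 3, hasThumbnail: true}
}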
@@ -15,6 +15,7 @@ import (
	_ "github.com/ollama/ollama/model/models/llama4"
	_ "github.com/ollama/ollama/model/models/mistral3"
	_ "github.com/ollama/ollama/model/models/mllama"
	_ "github.com/ollama/ollama/model/models/nemotronh"
	_ "github.com/ollama/ollama/model/models/nomicbert"
	_ "github.com/ollama/ollama/model/models/olmo3"
	_ "github.com/ollama/ollama/model/models/qwen2"
88 model/models/nemotronh/attention.go Normal file
@@ -0,0 +1,88 @@
package nemotronh

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// Attention implements simple attention without RoPE for Nemotron-H.
// Unlike Qwen3Next, Nemotron-H attention has:
//   - No RoPE (position info comes from Mamba2 layers)
//   - Standard scaled dot-product attention
type Attention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	hiddenDim := hiddenStates.Dim(0)
	nSeqTokens := hiddenStates.Dim(1)
	switch hiddenStates.Dim(2) {
	case 0:
		hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
	case 1:
	default:
		return nil, ErrUnsupportedBatchLayout
	}

	// Nemotron-H is currently clamped to num_parallel=1.
	if cache != nil && cache.IsSupportedForBatch() {
		if cache.numSeqs() != 1 {
			return nil, ErrUnsupportedBatchLayout
		}
		if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
			return nil, ErrUnsupportedBatchLayout
		}
	}
	batchSize := nSeqTokens
	hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)

	headDim := opts.getHeadDim()
	if headDim <= 0 {
		return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
	}

	// Q projection
	query := a.Query.Forward(ctx, hiddenStates)
	if query.Dim(0)%headDim != 0 {
		return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
	}
	numHeads := query.Dim(0) / headDim
	query = query.Reshape(ctx, headDim, numHeads, batchSize)

	// K projection
	key := a.Key.Forward(ctx, hiddenStates)
	if key.Dim(0)%headDim != 0 {
		return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
	}
	numKVHeads := key.Dim(0) / headDim
	key = key.Reshape(ctx, headDim, numKVHeads, batchSize)

	// V projection
	value := a.Value.Forward(ctx, hiddenStates)
	if value.Dim(0)%headDim != 0 {
		return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
	}
	if value.Dim(0)/headDim != numKVHeads {
		return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
	}
	value = value.Reshape(ctx, headDim, numKVHeads, batchSize)

	// Standard attention computation (no RoPE)
	scale := opts.attentionScale
	if scale == 0 {
		scale = 1.0 / math.Sqrt(float64(headDim))
	}
	attention := nn.Attention(ctx, query, key, value, scale, cache)

	// Flatten heads
	attention = attention.Reshape(ctx, headDim*numHeads, batchSize)

	// Output projection
	return a.Output.Forward(ctx, attention), nil
}
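// Illustrative GQA head derivation (a sketch with hypothetical projection
// widths; not part of the diff). The Forward above recovers head counts from
// the projection dimensions rather than from metadata: a 4096-wide query
// projection with headDim=128 implies 32 query heads, while 1024-wide K/V
// projections imply 8 KV heads.
func exampleHeadCounts() (numHeads, numKVHeads int) {
	const headDim = 128
	return 4096 / headDim, 1024 / headDim // 32 query heads, 8 KV heads
}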
55 model/models/nemotronh/cache.go Normal file
@@ -0,0 +1,55 @@
package nemotronh

import (
	"errors"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the layer requirements.
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")

var (
	_ kvcache.Cache           = (*HybridCache)(nil)
	_ kvcache.CheckpointCache = (*HybridCache)(nil)
)

// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
type HybridCache struct {
	*kvcache.Recurrent
}

func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
	base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
		Shift:               Shift,
		ConvDim:             convDim,
		ConvChannels:        convChannels,
		RecurrentStateSize:  ssmStateSize,
		CheckpointLogPrefix: "nemotronh",
	})
	return &HybridCache{Recurrent: base}
}

// SSMState returns the SSM state for current batch sequences.
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
	return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
}

// UpdateSSMState writes a new SSM state for current batch sequences.
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
	c.UpdateRecurrentState(ctx, layer, newState)
}

func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.SlotsTensor()
}

func (c *HybridCache) seqTokens() int {
	return c.SeqTokens()
}

func (c *HybridCache) numSeqs() int {
	return c.NumSeqs()
}
197 model/models/nemotronh/mamba2.go Normal file
@@ -0,0 +1,197 @@
package nemotronh

import (
	"log/slog"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
	Weight ml.Tensor `gguf:"weight"`
}

// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
// The forward pass follows llama.cpp's build_mamba2_layer:
//  1. Input projection: zxBCdt = SSMIn @ hidden
//  2. Split: z, xBC, dt
//  3. Concat with conv state, apply SSMConv, save new conv state
//  4. Apply SiLU to convolved xBC
//  5. Split: x, B, C
//  6. Add dt bias
//  7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
//  8. D skip: y = y + x * D
//  9. Swiglu with z: y = z * silu(y)
//  10. Group RMSNorm
//  11. Output projection
type Mamba2 struct {
	SSMIn      *nn.Linear  `gguf:"ssm_in"`     // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
	SSMConv1D  *convKernel `gguf:"ssm_conv1d"` // conv kernel
	SSMConv1DB ml.Tensor   `gguf:"ssm_conv1d.bias"`
	SSMDtB     ml.Tensor   `gguf:"ssm_dt.bias"` // dt bias [n_head]
	SSMA       ml.Tensor   `gguf:"ssm_a"`       // A parameter [1, n_head]
	SSMD       ml.Tensor   `gguf:"ssm_d"`       // D skip connection [1, n_head]
	SSMNorm    *nn.RMSNorm `gguf:"ssm_norm"`    // group norm
	SSMOut     *nn.Linear  `gguf:"ssm_out"`     // output projection
	Layer      int
}

func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	layer := m.Layer
	hiddenDim := hiddenStates.Dim(0)
	nSeqTokens := hiddenStates.Dim(1)
	switch hiddenStates.Dim(2) {
	case 0:
		hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
	case 1:
	default:
		return nil, ErrUnsupportedBatchLayout
	}

	// Nemotron-H is currently clamped to num_parallel=1.
	if cache != nil && cache.IsSupportedForBatch() {
		if cache.numSeqs() != 1 {
			return nil, ErrUnsupportedBatchLayout
		}
		if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
			return nil, ErrUnsupportedBatchLayout
		}
	}
	nSeqs := 1

	dConv := opts.ssmDConv
	dInner := opts.ssmDInner
	dState := opts.ssmDState
	nHead := opts.ssmNHead
	headDim := dInner / nHead
	nGroup := opts.ssmNGroup

	// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
	// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
	zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)

	// Split into z, xBC, dt
	// z: [head_dim, n_head, n_seq_tokens, n_seqs]
	z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
	z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)

	// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
	xBCSize := dInner + 2*nGroup*dState
	xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
	if nSeqTokens == 1 {
		xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
	}

	// dt: [n_head, n_seq_tokens, n_seqs]
	dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
	if nSeqTokens == 1 {
		dt = dt.Reshape(ctx, nHead, 1, nSeqs)
	} else {
		dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
	}

	// Get conv state from cache
	convStates, err := cache.ConvState(ctx, layer)
	if err != nil {
		slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
		convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
	}

	// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
	convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)

	// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
	var xBCT ml.Tensor
	if nSeqTokens == 1 {
		xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
	} else {
		// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
		xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
	}

	// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
	convInput := convStates.Concat(ctx, xBCT, 0)

	// Save new conv state (last d_conv-1 columns)
	lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
	cache.UpdateConvState(ctx, layer, lastConvStates)

	// Apply SSM convolution
	xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)

	// Add conv bias
	if m.SSMConv1DB != nil {
		xBC = xBC.Add(ctx, m.SSMConv1DB)
	}

	// Apply SiLU
	xBC = xBC.SILU(ctx)

	// Split xBC into x, B, C
	// x: [head_dim, n_head, n_seq_tokens, n_seqs]
	x := xBC.Slice(ctx, 0, 0, dInner, 1)
	x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)

	// B: [d_state, n_group, n_seq_tokens, n_seqs]
	B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
	B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)

	// C: [d_state, n_group, n_seq_tokens, n_seqs]
	C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
	C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)

	// Add dt bias
	dt = dt.Add(ctx, m.SSMDtB)

	// Get SSM state from cache
	state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
	if err != nil {
		slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
		state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
	}

	// SSMScan
	// state: [d_state, head_dim, n_head, n_seqs]
	// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
	ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())

	// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
	yElems := headDim * nHead * nSeqTokens * nSeqs
	y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)

	stateOffsetBytes := yElems * x.Stride(0)
	stateElems := dState * headDim * nHead * nSeqs
	newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
	newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)

	// Update SSM state in cache
	cache.UpdateSSMState(ctx, layer, newState)

	// D skip connection: y = y + x * D
	if m.SSMD != nil {
		// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
		xD := x.Mul(ctx, m.SSMD)
		y = y.Add(ctx, xD)
	}

	// Swiglu with z: y = z * silu(y)
	y = z.SILU(ctx, y)

	// Group RMSNorm
	if m.SSMNorm != nil {
		// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
		innerPerGroup := dInner / nGroup
		y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
		y = m.SSMNorm.Forward(ctx, y, opts.eps)
	}

	// Reshape back to [d_inner, n_seq_tokens, n_seqs]
	y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)

	// Output projection
	out := m.SSMOut.Forward(ctx, y)

	// Reshape to 2D for consistency with attention output
	return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}
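// Worked example of the zxBCdt split offsets above (a sketch with
// hypothetical dimensions; not part of the diff): for dInner=4096, nGroup=8,
// dState=128, nHead=64 the input projection width is
// 2*4096 + 2*8*128 + 64 = 10304, and the slices land at z=[0,4096),
// xBC=[4096,10240), dt=[10240,10304), which is why xBCSize is computed as
// dInner + 2*nGroup*dState.
func exampleSplitOffsets(dInner, nGroup, dState, nHead int) (zEnd, xBCEnd, dtEnd int) {
	xBCSize := dInner + 2*nGroup*dState
	return dInner, dInner + xBCSize, dInner + xBCSize + nHead
}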
417 model/models/nemotronh/model.go Normal file
@@ -0,0 +1,417 @@
package nemotronh

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/tokenizer"
)

// Options contains model configuration
type Options struct {
	hiddenSize int
	numHeads   int // attention heads
	numKVHeads int // KV heads for attention layers
	headDim    int
	eps        float32

	// Mamba2 SSM config
	ssmDConv  int // conv kernel size
	ssmDInner int // inner dimension (d_inner)
	ssmDState int // state dimension
	ssmNHead  int // number of SSM heads (dt_rank)
	ssmNGroup int // number of groups for B, C

	// Per-layer configuration
	isRecurrent []bool // true = Mamba2, false = attention or FFN
	nFF         []int  // n_ff per layer (0 = attention-only)

	// Attention scale
	attentionScale float64

	// MoE config
	numExperts            int
	numExpertsUsed        int
	expertWeightsNorm     bool
	expertWeightsScale    float32
	expertWeightsNormClip float32
}

func (o Options) getHeadDim() int {
	if o.headDim > 0 {
		return o.headDim
	}
	if o.numHeads <= 0 {
		return 0
	}
	return o.hiddenSize / o.numHeads
}

// Operator is the interface for layer operators (Mamba2 or Attention)
type Operator interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}

// MLP is the interface for feedforward networks
type MLP interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}

// Layer represents a single transformer layer
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Operator      Operator // Mamba2, Attention, or nil (for FFN-only layers)
	MLP           MLP      // Dense or MoE FFN, or nil
}

func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	residual := hiddenStates

	// Pre-layer norm
	hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)

	// Layer operator (Mamba2, Attention, or FFN)
	if l.Operator != nil {
		var err error
		hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
		if err != nil {
			return nil, err
		}
	} else if l.MLP != nil {
		// FFN-only layer
		hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
	}

	// Output projection for last layer
	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	// Residual connection
	return hiddenStates.Add(ctx, residual), nil
}

// Model is the main Nemotron-H model
type Model struct {
	model.Base
	tokenizer.Tokenizer

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	Layers []Layer `gguf:"blk"`

	*Options
}

// Shift is used for KV cache position shifting.
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key, nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	cache := m.Cache.(*HybridCache)

	for i, layer := range m.Layers {
		cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}

		var err error
		hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
		if err != nil {
			return nil, err
		}
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates), nil
}

func New(c fs.Config) (model.Model, error) {
	numLayers := int(c.Uint("block_count"))
	layers := make([]Layer, numLayers)

	// Get per-layer configuration from GGUF metadata
	// Use the same interface pattern as qwen3next
	type perLayerConfig interface {
		HeadCount() []uint64
		HeadCountKV() []uint64
		FFNLength() []uint64
	}

	var headCount []uint64
	var headCountKV []uint64
	var ffnLength []uint64

	if plc, ok := c.(perLayerConfig); ok {
		headCount = plc.HeadCount()
		headCountKV = plc.HeadCountKV()
		ffnLength = plc.FFNLength()
	}

	// Build per-layer arrays with defaults
	isRecurrent := make([]bool, numLayers)
	nFF := make([]int, numLayers)

	for i := range numLayers {
		// Get per-layer values
		kvHeads := uint64(1) // Default non-zero
		if i < len(headCountKV) {
			kvHeads = headCountKV[i]
		}
		ff := uint64(0)
		if i < len(ffnLength) {
			ff = ffnLength[i]
		}
		nFF[i] = int(ff)

		// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
		// This matches llama.cpp behavior for Nemotron-H
		isRecurrent[i] = kvHeads == 0 && ff == 0
	}

	// Determine if MoE
	isMoE := c.Uint("expert_count") > 0

	for i := range layers {
		if isRecurrent[i] {
			// Mamba2 layer
			layers[i].Operator = &Mamba2{Layer: i}
		} else if nFF[i] == 0 {
			// Attention-only layer (n_head_kv > 0, n_ff == 0)
			layers[i].Operator = &Attention{}
		} else {
			// FFN layer (n_ff > 0)
			if isMoE {
				layers[i].MLP = &MoESparse{}
			} else {
				layers[i].MLP = &Dense{}
			}
		}
	}

	// Get attention head configuration
	numHeads := int(c.Uint("attention.head_count"))
	if numHeads == 0 {
		for i := range numLayers {
			if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
				numHeads = int(headCount[i])
				break
			}
		}
	}
	numKVHeads := int(c.Uint("attention.head_count_kv"))
	if numKVHeads == 0 {
		for i := range numLayers {
			if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
				numKVHeads = int(headCountKV[i])
				break
			}
		}
		if numKVHeads == 0 {
			numKVHeads = numHeads
		}
	}

	headDim := int(c.Uint("attention.head_dim"))
	if headDim == 0 {
		if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
			headDim = keyLength
		} else if numHeads > 0 {
			headDim = int(c.Uint("embedding_length")) / numHeads
		}
	}
	if headDim <= 0 {
		return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
	}
	if numHeads <= 0 {
		// Attention layers derive per-layer head counts from projection weights.
		// Keep a non-zero default to avoid invalid option math.
		numHeads = 1
	}

	numExperts := int(c.Uint("expert_count"))
	numExpertsUsed := int(c.Uint("expert_used_count"))
	if numExperts > 0 {
		if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
			return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
		}
	}

	opts := &Options{
		hiddenSize:            int(c.Uint("embedding_length")),
		numHeads:              numHeads,
		numKVHeads:            numKVHeads,
		headDim:               headDim,
		eps:                   c.Float("attention.layer_norm_rms_epsilon"),
		ssmDConv:              int(c.Uint("ssm.conv_kernel")),
		ssmDInner:             int(c.Uint("ssm.inner_size")),
		ssmDState:             int(c.Uint("ssm.state_size")),
		ssmNHead:              int(c.Uint("ssm.time_step_rank")),
		ssmNGroup:             int(c.Uint("ssm.group_count")),
		isRecurrent:           isRecurrent,
		nFF:                   nFF,
		attentionScale:        float64(c.Float("attention.scale")),
		numExperts:            numExperts,
		numExpertsUsed:        numExpertsUsed,
		expertWeightsNorm:     c.Bool("expert_weights_norm", false),
		expertWeightsScale:    c.Float("expert_weights_scale", 1.0),
		expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
	}

	// Calculate cache dimensions
	convDim := max(0, opts.ssmDConv-1)
	convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
	ssmHeadDim := 0
	if opts.ssmNHead > 0 {
		ssmHeadDim = opts.ssmDInner / opts.ssmNHead
	}
	ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead

	m := Model{
		Tokenizer: tokenizer.NewBytePairEncoding(
			&tokenizer.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
			},
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		),
		Layers:  layers,
		Options: opts,
	}

	m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
	return &m, nil
}

func init() {
	model.Register("nemotron_h", New)
	model.Register("nemotron_h_moe", New)
}

// Ensure Model implements model.Model
var _ model.Model = (*Model)(nil)

// Dense implements standard feedforward with ReLU-squared activation
type Dense struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
	// up -> ReLU-squared -> down
	up := d.Up.Forward(ctx, x)
	up = up.RELU(ctx)
	up = up.Mul(ctx, up) // Square
	return d.Down.Forward(ctx, up)
}

// MoESparse implements MoE with shared experts for Nemotron-H-MoE
type MoESparse struct {
	Router *nn.Linear      `gguf:"ffn_gate_inp"`
	Up     *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
	Bias   ml.Tensor       `gguf:"exp_probs_b.bias,alt:exp_probs_b"`

	LatentIn  *nn.Linear `gguf:"ffn_latent_in"`
	LatentOut *nn.Linear `gguf:"ffn_latent_out"`

	// Shared experts
	SharedUp   *nn.Linear `gguf:"ffn_up_shexp"`
	SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}

func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	hiddenDim := hiddenStates.Dim(0)
	seqLen := hiddenStates.Dim(1)
	batchSize := hiddenStates.Dim(2)
	if batchSize == 0 {
		batchSize = 1
	}
	hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)

	// Router logits with sigmoid gating
	routerLogits := m.Router.Forward(ctx, hiddenStates2D)

	// Weights come from unbiased sigmoid probabilities.
	probs := routerLogits.Sigmoid(ctx)

	// Selection uses optional bias.
	selectionProbs := probs
	if m.Bias != nil {
		selectionProbs = selectionProbs.Add(ctx, m.Bias)
	}

	// Select top-k experts
	selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
	routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)

	// Normalize routing weights
	if opts.expertWeightsNorm {
		routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
		weightsSum := routingWeights.SumRows(ctx)
		weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
		routingWeights = routingWeights.Div(ctx, weightsSum)
		routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
	}

	// Scale routing weights
	if opts.expertWeightsScale != 1.0 {
		routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
	}

	routedInput := hiddenStates2D
	if m.LatentIn != nil {
		routedInput = m.LatentIn.Forward(ctx, routedInput)
	}
	hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))

	// Expert computation with ReLU-squared activation
	upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
	upOut = upOut.RELU(ctx)
	upOut = upOut.Mul(ctx, upOut) // Square
	experts := m.Down.Forward(ctx, upOut, selectedExperts)
	experts = experts.Mul(ctx, routingWeights)

	// Sum over experts
	moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}
	if m.LatentOut != nil {
		moeOut = m.LatentOut.Forward(ctx, moeOut)
	}

	// Add shared experts if present
	if m.SharedUp != nil {
		sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
		sharedUp = sharedUp.RELU(ctx)
		sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
		sharedOut := m.SharedDown.Forward(ctx, sharedUp)
		moeOut = moeOut.Add(ctx, sharedOut)
	}

	return moeOut
}
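// Layer classification at a glance (a sketch of the rules implemented in New
// above; not part of the diff, and the helper name is hypothetical): each
// block is typed purely from its per-layer metadata, matching the llama.cpp
// convention for Nemotron-H.
func exampleLayerKind(kvHeads, ff uint64, isMoE bool) string {
	switch {
	case kvHeads == 0 && ff == 0:
		return "mamba2" // recurrent SSM layer
	case ff == 0:
		return "attention" // attention-only layer
	case isMoE:
		return "moe" // MoESparse FFN
	default:
		return "dense" // Dense FFN
	}
}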
@@ -32,6 +32,8 @@ type LFM2Parser struct {
	hasThinkingSupport       bool
	needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
	needsContentLeadingTrim  bool // trim leading whitespace after </think> tag
	toolNames                map[string]struct{}
	hasTools                 bool
}

func (p *LFM2Parser) HasToolSupport() bool {
@@ -63,6 +65,13 @@ func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.T
}

func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.toolNames = make(map[string]struct{}, len(tools))
	p.hasTools = len(tools) > 0
	for _, tool := range tools {
		if tool.Function.Name != "" {
			p.toolNames[tool.Function.Name] = struct{}{}
		}
	}
	p.setInitialState(lastMessage, thinkValue)
	return tools
}
@@ -105,9 +114,33 @@ func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string,
		}
	}

	// Fallback for models that emit bare tool calls without <|tool_call_*|> wrappers.
	if done && len(toolCalls) == 0 && p.hasTools {
		candidate := strings.TrimSpace(contentSb.String())
		if fallbackCalls, parseErr := p.parseToolCallsContent(candidate); parseErr == nil && p.toolCallsAllowed(fallbackCalls) {
			contentSb.Reset()
			toolCalls = append(toolCalls, fallbackCalls...)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *LFM2Parser) toolCallsAllowed(calls []api.ToolCall) bool {
	if len(calls) == 0 {
		return false
	}
	if len(p.toolNames) == 0 {
		return true
	}
	for _, call := range calls {
		if _, ok := p.toolNames[call.Function.Name]; !ok {
			return false
		}
	}
	return true
}

func (p *LFM2Parser) parseEvents() []lfm2Event {
	var all []lfm2Event

@@ -269,36 +302,16 @@ func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
	return events, false
}

// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
// parseToolCallsContent parses one or more Python-style tool calls.
// Example: [func1(arg='v'), func2(x=1)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
	content = strings.TrimSpace(content)

	// Try JSON format first: {"name": "func", "arguments": {...}}
	var parsed struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments"`
	}
	// Be tolerant of malformed outputs that include wrapper tags without proper pairing.
	content = strings.TrimSpace(strings.TrimPrefix(content, lfm2ToolCallStartTag))
	content = strings.TrimSpace(strings.TrimSuffix(content, lfm2ToolCallEndTag))

	if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
		var args api.ToolCallFunctionArguments
		if len(parsed.Arguments) > 0 {
			if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
				return nil, err
			}
		} else {
			args = api.NewToolCallFunctionArguments()
		}

		return []api.ToolCall{{
			Function: api.ToolCallFunction{
				Name:      parsed.Name,
				Arguments: args,
			},
		}}, nil
	}

	// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
	// Parse Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
	return p.parsePythonStyleToolCalls(content)
}

@@ -417,21 +430,16 @@ func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error)

// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
	// Simple state machine to parse key='value' pairs
	// Handles: command='ls', flag="-la", count=42, enabled=true
	var key string
	i := 0

	for i < len(argsStr) {
		// Skip whitespace
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
		// Skip separators and whitespace.
		for i < len(argsStr) && (argsStr[i] == ',' || unicode.IsSpace(rune(argsStr[i]))) {
			i++
		}
		if i >= len(argsStr) {
			break
		}

		// Parse key
		keyStart := i
		for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
			i++
@@ -439,60 +447,238 @@ func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error
		if i >= len(argsStr) || argsStr[i] != '=' {
			return errors.New("invalid argument: expected '='")
		}
		key = strings.TrimSpace(argsStr[keyStart:i])

		key := strings.TrimSpace(argsStr[keyStart:i])
		if key == "" {
			return errors.New("invalid argument: empty key")
		}
		i++ // skip '='

		// Skip whitespace after =
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
		for i < len(argsStr) && unicode.IsSpace(rune(argsStr[i])) {
			i++
		}

		// Parse value
		var value string
		if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
			// Quoted string
			quote := argsStr[i]
			i++
			valueStart := i
			for i < len(argsStr) && argsStr[i] != quote {
				if argsStr[i] == '\\' && i+1 < len(argsStr) {
					i += 2 // skip escaped char
				} else {
					i++
				}
			}
			value = argsStr[valueStart:i]
			if i < len(argsStr) {
				i++ // skip closing quote
			}
			args.Set(key, value)
		} else {
			// Unquoted value (number, bool, etc)
			valueStart := i
			for i < len(argsStr) && argsStr[i] != ',' {
				i++
			}
			value = strings.TrimSpace(argsStr[valueStart:i])

			// Try to parse as number or bool
			if v, err := strconv.ParseInt(value, 10, 64); err == nil {
				args.Set(key, v)
			} else if v, err := strconv.ParseFloat(value, 64); err == nil {
				args.Set(key, v)
			} else if value == "true" {
				args.Set(key, true)
			} else if value == "false" {
				args.Set(key, false)
			} else {
				args.Set(key, value)
			}
		if i >= len(argsStr) {
			return errors.New("invalid argument: missing value")
		}

		// Skip comma and whitespace
		for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
		value, next, err := parsePythonArgValue(argsStr, i)
		if err != nil {
			return err
		}
		args.Set(key, value)
		i = next

		// Optional trailing comma before next key/value.
		if i < len(argsStr) && argsStr[i] == ',' {
			i++
		}
	}

	return nil
}

func parsePythonArgValue(s string, i int) (any, int, error) {
	if i >= len(s) {
		return nil, i, errors.New("invalid argument: missing value")
	}

	// Quoted string literal.
	if s[i] == '\'' || s[i] == '"' {
		quote := s[i]
		i++
		start := i
		for i < len(s) {
			if s[i] == '\\' && i+1 < len(s) {
				i += 2
				continue
			}
			if s[i] == quote {
				value := s[start:i]
				i++
				return value, i, nil
			}
			i++
		}
		return nil, i, errors.New("invalid argument: unterminated string")
	}

	// Unquoted literal. Consume until top-level comma.
	start := i
	depthParen, depthSquare, depthCurly := 0, 0, 0
	inString := false
	var quote byte
	escaped := false

	for i < len(s) {
		ch := s[i]
		if inString {
			if escaped {
				escaped = false
			} else if ch == '\\' {
				escaped = true
			} else if ch == quote {
				inString = false
			}
			i++
			continue
		}

		switch ch {
		case '\'', '"':
			inString = true
			quote = ch
		case '(':
			depthParen++
		case ')':
			if depthParen > 0 {
				depthParen--
			}
		case '[':
			depthSquare++
		case ']':
			if depthSquare > 0 {
				depthSquare--
			}
		case '{':
			depthCurly++
		case '}':
			if depthCurly > 0 {
				depthCurly--
			}
		case ',':
			if depthParen == 0 && depthSquare == 0 && depthCurly == 0 {
				token := strings.TrimSpace(s[start:i])
				value, err := parsePythonLiteral(token)
				return value, i, err
			}
		}
		i++
	}

	token := strings.TrimSpace(s[start:i])
	value, err := parsePythonLiteral(token)
	return value, i, err
}

func parsePythonLiteral(token string) (any, error) {
	switch token {
	case "":
		return "", nil
	case "true", "True":
		return true, nil
	case "false", "False":
		return false, nil
	case "null", "None":
		return nil, nil
	}

	if v, err := strconv.ParseInt(token, 10, 64); err == nil {
		return v, nil
	}
	if v, err := strconv.ParseFloat(token, 64); err == nil {
		return v, nil
	}

	if strings.HasPrefix(token, "[") || strings.HasPrefix(token, "{") {
		var parsed any
		if err := json.Unmarshal([]byte(token), &parsed); err == nil {
			return parsed, nil
		}

		if converted, err := pythonLiteralToJSON(token); err == nil {
			if err := json.Unmarshal([]byte(converted), &parsed); err == nil {
				return parsed, nil
			}
		}
	}

	return token, nil
}

func pythonLiteralToJSON(s string) (string, error) {
	var out strings.Builder
	out.Grow(len(s) + len(s)/8)

	inString := false
	var quote byte
	escaped := false

	for i := 0; i < len(s); i++ {
		ch := s[i]

		if inString {
			if escaped {
				out.WriteByte(ch)
				escaped = false
				continue
			}

			if ch == '\\' {
				out.WriteByte(ch)
				escaped = true
				continue
			}

			if ch == quote {
				out.WriteByte('"')
				inString = false
				continue
			}

			if quote == '\'' && ch == '"' {
				out.WriteString(`\"`)
				continue
			}

			out.WriteByte(ch)
			continue
		}

		if ch == '\'' || ch == '"' {
			inString = true
			quote = ch
			escaped = false
			out.WriteByte('"')
			continue
		}

		// Replace Python identifiers with JSON equivalents when outside strings.
		if isIdentStart(ch) {
			j := i + 1
			for j < len(s) && isIdentPart(s[j]) {
				j++
			}

			ident := s[i:j]
			switch ident {
			case "True":
				out.WriteString("true")
			case "False":
				out.WriteString("false")
			case "None":
				out.WriteString("null")
			default:
				out.WriteString(ident)
			}

			i = j - 1
			continue
		}

		out.WriteByte(ch)
	}

	if inString {
		return "", errors.New("unterminated string")
	}

	return out.String(), nil
}

func isIdentStart(b byte) bool {
	return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || b == '_'
}

func isIdentPart(b byte) bool {
	return isIdentStart(b) || (b >= '0' && b <= '9')
}
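// Example of the Python-style fallback in action (a sketch; the argument
// names are hypothetical, parsePythonArgs and parsePythonLiteral are the
// functions changed above): a bare model output such as
// [get_weather(location='Paris', retries=2, verbose=True)] now parses without
// <|tool_call_*|> wrappers, with True/False/None mapped to JSON
// booleans/null.
func exampleBareToolCallArgs() (api.ToolCallFunctionArguments, error) {
	args := api.NewToolCallFunctionArguments()
	err := parsePythonArgs("location='Paris', retries=2, verbose=True", &args)
	return args, err // {"location":"Paris","retries":2,"verbose":true}
}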
@@ -39,7 +39,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "tool_call_simple",
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "I'll check the weather.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
@@ -55,7 +55,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
input: "Getting weather for both cities.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|tool_call_start|>[get_weather(location=\"London\")]<|tool_call_end|>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
@@ -79,7 +79,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "complex_tool_arguments",
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
input: "Processing data.<|tool_call_start|>[process_data(items=['item1','item2'], config={'enabled': True, 'threshold': 0.95})]<|tool_call_end|>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
@@ -96,7 +96,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "thinking_with_tool_call",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
@@ -144,16 +144,16 @@ func TestLFM2Parser(t *testing.T) {
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
name: "tool_call_with_text_args",
input: "Searching for information.<|tool_call_start|>[search(query='beijing weather', language='zh')]<|tool_call_end|>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "北京天气",
"language": "中文",
"query": "beijing weather",
"language": "zh",
}),
},
},
@@ -169,7 +169,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "empty_tool_call_args",
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
input: "Pinging server.<|tool_call_start|>[ping()]<|tool_call_end|>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
@@ -353,7 +353,7 @@ func TestLFM2Parser_Streaming(t *testing.T) {
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "[get_weather(", "location=\"Paris\")]", "<|tool_call_end|>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
@@ -381,16 +381,16 @@ func TestLFM2Parser_Streaming(t *testing.T) {
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
name: "streaming_tool_call_with_split_python",
chunks: []string{"Processing.", "<|tool_call_start|>", "[calc(", "x=42, ", "y=24)]", "<|tool_call_end|>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
"x": float64(42),
"y": float64(24),
"x": int64(42),
"y": int64(24),
}),
},
},
@@ -516,8 +516,8 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
expectError bool
}{
{
name: "valid_tool_call",
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
name: "python_style_single_call",
content: `get_weather(location="Paris")`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
@@ -528,21 +528,33 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "complex_arguments",
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
name: "python_style_with_brackets",
content: `[get_weather(location="Paris")]`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Name: "get_weather",
Arguments: testArgs(map[string]any{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
"location": "Paris",
}),
},
},
},
{
name: "empty_arguments",
content: `{"name":"ping","arguments":{}}`,
name: "python_style_complex_arguments",
content: `process(items=['a', 'b'], config={'enabled': True})`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
"items": []any{"a", "b"},
"config": map[string]any{"enabled": true},
}),
},
},
},
{
name: "python_style_empty_arguments",
content: `ping()`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
@@ -551,44 +563,13 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "unicode_in_tool_name",
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
"城市": "北京",
}),
},
},
},
{
name: "numeric_arguments",
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
{
name: "invalid_json",
content: `{invalid json}`,
name: "missing_parenthesis",
content: `get_weather location="Paris")`,
expectError: true,
},
{
name: "missing_name",
content: `{"arguments":{"arg":"value"}}`,
expectError: true,
},
{
name: "empty_name",
content: `{"name":"","arguments":{"arg":"value"}}`,
name: "invalid_argument_format",
content: `bash(command)`,
expectError: true,
},
}
@@ -645,6 +626,24 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
},
},
},
{
name: "python_style_complex_literals",
content: `[AskUserQuestion(question="What's up?", headers=['Hello!', 'How can I help you?'], options=['Debugging help', 'Code writing assistance'], multiSelect=False, metadata={'priority': 1, 'active': True})]`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "AskUserQuestion",
Arguments: testArgs(map[string]any{
"question": "What's up?",
"headers": []any{"Hello!", "How can I help you?"},
"options": []any{"Debugging help", "Code writing assistance"},
"multiSelect": false,
"metadata": map[string]any{"priority": float64(1), "active": true},
}),
},
},
},
},
{
name: "single_python_style_call",
content: `bash(command='ls -la')`,
@@ -673,6 +672,34 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
},
},
},
{
name: "single_call_with_orphan_end_tag",
content: `[bash(command='ls')]<|tool_call_end|>`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "ls",
}),
},
},
},
},
{
name: "single_call_with_wrapper_tags",
content: `<|tool_call_start|>[bash(command='pwd')]<|tool_call_end|>`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "pwd",
}),
},
},
},
},
{
name: "multiple_different_functions",
content: `[get_weather(location='Paris'),search(query='news')]`,
@@ -1086,3 +1113,106 @@ func TestLFM2Parser_EdgeCases(t *testing.T) {
})
}
}

func TestLFM2Parser_BareToolCallFallback(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})

content, thinking, calls, err := parser.Add(`[get_weather(location="Paris")]`, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}

if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name get_weather, got %q", calls[0].Function.Name)
}
}

func TestLFM2Parser_BareUnknownToolCallDoesNotParse(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})

input := `[unknown_tool(location="Paris")]`
content, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}

if content != input {
t.Fatalf("expected content to be preserved, got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}

func TestLFM2Parser_ImagePlaceholdersPreserved(t *testing.T) {
tests := []struct {
name string
input string
}{
{
name: "indexed_img_placeholder",
input: "[img-0]describe this image",
},
{
name: "template_image_placeholder",
input: "<image>describe this image",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "bash",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})

content, thinking, calls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}

if content != tt.input {
t.Fatalf("expected content %q, got %q", tt.input, content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
})
}
}

@@ -57,6 +57,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"lfm2"},
{"lfm2-thinking"},
{"harmony"},
}

@@ -1,7 +1,9 @@
package renderers

import (
"bytes"
"encoding/json"
"sort"
"strings"

"github.com/ollama/ollama/api"
@@ -9,18 +11,218 @@ import (

type LFM2Renderer struct {
IsThinking bool
useImgTags bool
}

const lfm2BOSToken = "<|startoftext|>"

const (
lfm2ToolListStartTag = "<|tool_list_start|>"
lfm2ToolListEndTag = "<|tool_list_end|>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
lfm2ToolResponseStartTag = "<|tool_response_start|>"
lfm2ToolResponseEndTag = "<|tool_response_end|>"
)

func lfm2RenderSystemContent(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
var sb strings.Builder
for _, item := range v {
obj, ok := item.(map[string]any)
if !ok {
continue
}

if itemType, _ := obj["type"].(string); itemType == "text" {
if text, ok := obj["text"].(string); ok {
sb.WriteString(text)
}
}
}
return sb.String()
default:
return ""
}
}

func lfm2JSON(v any) string {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
if err := enc.Encode(v); err != nil {
fallback, _ := json.Marshal(v)
return string(fallback)
}

encoded := bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})

// HF `tojson` defaults to `json.dumps(..., separators=None)`, which inserts
// a space after commas and colons.
var out strings.Builder
out.Grow(len(encoded) + len(encoded)/8)

inString := false
escaped := false
for i, b := range encoded {
out.WriteByte(b)

if inString {
if escaped {
escaped = false
continue
}
if b == '\\' {
escaped = true
continue
}
if b == '"' {
inString = false
}
continue
}

if b == '"' {
inString = true
continue
}

if (b == ':' || b == ',') && i+1 < len(encoded) {
next := encoded[i+1]
if next != ' ' && next != '\n' && next != '\r' && next != '\t' {
out.WriteByte(' ')
}
}
}

return out.String()
}

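A quick sketch of the resulting spacing. Illustrative only, assuming the same package as lfm2JSON and an fmt import:

func exampleLFM2JSONSpacing() {
	// Map keys marshal in sorted order; separators outside strings gain a
	// trailing space, while the comma inside "x,y" is left alone because
	// the loop tracks string boundaries.
	fmt.Println(lfm2JSON(map[string]any{"a": 1, "b": "x,y"}))
	// Output: {"a": 1, "b": "x,y"}
}
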
func lfm2ImagePlaceholder(useImgTags bool) string {
if useImgTags {
return "[img]"
}

return "<image>"
}

func lfm2RenderContent(content any, useImgTags bool) string {
switch v := content.(type) {
case string:
return v
case []any:
var sb strings.Builder
for _, item := range v {
obj, ok := item.(map[string]any)
if !ok {
sb.WriteString(lfm2JSON(item))
continue
}

itemType, _ := obj["type"].(string)
switch itemType {
case "image":
sb.WriteString(lfm2ImagePlaceholder(useImgTags))
case "text":
if text, ok := obj["text"].(string); ok {
sb.WriteString(text)
} else {
sb.WriteString(lfm2JSON(item))
}
default:
sb.WriteString(lfm2JSON(item))
}
}
return sb.String()
default:
return lfm2JSON(content)
}
}

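An illustrative comment on how structured content flattens (the input value is hypothetical):

// lfm2RenderContent([]any{
//	map[string]any{"type": "image"},
//	map[string]any{"type": "text", "text": "What is this?"},
// }, false)
// yields "<image>What is this?"; with useImgTags=true the placeholder
// would be "[img]" instead.
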
func lfm2ToolSchema(tool api.Tool) any {
if tool.Function.Name == "" {
return tool
}

// LFM2 templates are typically fed function-schema objects (name/description/parameters).
return tool.Function
}

func lfm2ToolCallArgument(v any) string {
return lfm2JSON(v)
}

func lfm2RenderToolCalls(calls []api.ToolCall) string {
var sb strings.Builder

sb.WriteString(lfm2ToolCallStartTag)
sb.WriteString("[")
for i, tc := range calls {
if i > 0 {
sb.WriteString(",")
}

sb.WriteString(tc.Function.Name)
sb.WriteString("(")

keys := make([]string, 0, tc.Function.Arguments.Len())
for key := range tc.Function.Arguments.All() {
keys = append(keys, key)
}
sort.Strings(keys)

for j, key := range keys {
if j > 0 {
sb.WriteString(",")
}
value, _ := tc.Function.Arguments.Get(key)
sb.WriteString(key)
sb.WriteString("=")
sb.WriteString(lfm2ToolCallArgument(value))
}

sb.WriteString(")")
}
sb.WriteString("]")
sb.WriteString(lfm2ToolCallEndTag)

return sb.String()
}

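The resulting wire format, shown as an illustrative comment (values are hypothetical):

// For a single call get_weather with argument location="Paris",
// lfm2RenderToolCalls produces:
//
//	<|tool_call_start|>[get_weather(location="Paris")]<|tool_call_end|>
//
// Argument keys are sorted first, so the rendering is deterministic.
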
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
content := lfm2RenderContent(message.Content, r.useImgTags)
if len(message.Images) == 0 {
return content
}

// chatPrompt may already have inserted [img] / [img-n] placeholders.
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
return content
}

var sb strings.Builder
placeholder := lfm2ImagePlaceholder(r.useImgTags)
for range message.Images {
sb.WriteString(placeholder)
}
sb.WriteString(content)
return sb.String()
}

func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder

// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Follow Liquid tool-use formatting for LFM2 tool wrappers.
sb.WriteString(lfm2BOSToken)

// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
firstSystemContent = lfm2RenderSystemContent(messages[0].Content)
startIdx = 1
}

@@ -29,18 +231,17 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
firstSystemContent += "List of tools: "
firstSystemContent += lfm2ToolListStartTag
firstSystemContent += "["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
firstSystemContent += lfm2JSON(lfm2ToolSchema(tool))
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
firstSystemContent += lfm2ToolListEndTag
}

// Output first system block if it has content
@@ -50,6 +251,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
sb.WriteString("<|im_end|>\n")
}

keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())

// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
@@ -59,85 +262,47 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
}
}

// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0

for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"

case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
sb.WriteString("<|im_start|>")
sb.WriteString(message.Role)
sb.WriteString("\n")

case "assistant":
sb.WriteString("<|im_start|>assistant\n")

// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex

// Process content (may need thinking stripped)
content := message.Content

// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
content := r.renderMessageContent(message)
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
content = strings.TrimSpace(content[idx+len("</think>"):])
}

if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}

for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
}
if message.Role == "assistant" && len(message.ToolCalls) > 0 && !strings.Contains(content, lfm2ToolCallStartTag) {
if strings.TrimSpace(content) == "" {
content = lfm2RenderToolCalls(message.ToolCalls) + content
} else {
sb.WriteString(content)
content = lfm2RenderToolCalls(message.ToolCalls) + "\n" + content
}
}
if message.Role == "tool" && !strings.Contains(content, lfm2ToolResponseStartTag) {
content = lfm2ToolResponseStartTag + content + lfm2ToolResponseEndTag
}

sb.WriteString(content)
if !prefill {
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true

case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}

// Add generation prompt
needsGenerationPrompt := true
if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" {
needsGenerationPrompt = false
}

if needsGenerationPrompt {
// RenderWithRenderer uses add_generation_prompt=true for chat rendering,
// unless we're prefilling a trailing assistant message.
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}

return sb.String(), nil

@@ -8,73 +8,136 @@ import (
"github.com/ollama/ollama/api"
)

func TestLFM2Renderer(t *testing.T) {
func TestLFM2Renderer_ChatTemplateParity(t *testing.T) {
tests := []struct {
name string
renderer *LFM2Renderer
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
name: "user_only",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
name: "system_and_user",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
name: "tools_without_system",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "user", Content: "Use tools"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nList of tools: <|tool_list_start|>[{\"name\": \"get_weather\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nUse tools<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
name: "first_system_combined_with_tools",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "Follow instructions."},
{Role: "user", Content: "Do work"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_a",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_b",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>system\nFollow instructions.\nList of tools: <|tool_list_start|>[{\"name\": \"tool_a\", \"parameters\": {\"type\": \"object\", \"properties\": null}}, {\"name\": \"tool_b\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant_tool_calls_and_tool_responses_are_rendered",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Call a tool"},
{
Role: "assistant",
Content: "",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|im_end|>\n<|im_start|>tool\n<|tool_response_start|>22C<|tool_response_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant_tool_calls_with_content_preserves_both",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Call a tool"},
{
Role: "assistant",
Content: "Checking now.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>\nChecking now.",
},
{
name: "thinking_strips_non_last_assistant_when_disabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -82,11 +145,11 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
name: "thinking_preserves_past_assistant_when_enabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -94,334 +157,137 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
name: "assistant with tool calls",
name: "arbitrary_roles_are_rendered_verbatim",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "developer", Content: "Do X"},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
expected: "<|startoftext|><|im_start|>developer\nDo X<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with content and tool calls",
name: "empty_messages_still_add_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: nil,
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>assistant\n",
},
{
name: "assistant_prefill_no_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello",
},
}

renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
rendered, err := tt.renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestLFM2Renderer_Images(t *testing.T) {
tests := []struct {
name string
renderer *LFM2Renderer
message api.Message
expected string
}{
{
name: "single_image_default_placeholder",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple_images_default_placeholder",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "Describe these images.",
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
},
expected: "<|startoftext|><|im_start|>user\n<image><image>Describe these images.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "single_image_img_tag_placeholder",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_indexed_img_placeholder_not_duplicated",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "[img-0]Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_template_image_placeholder_not_duplicated",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "<image>Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.renderer.Render([]api.Message{tt.message}, nil, nil)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, got); diff != "" {
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestLFM2Renderer_JSONFormatting(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "echo",
Description: "<html>",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
}

got := lfm2JSON(tool)
want := "{\"type\": \"function\", \"function\": {\"name\": \"echo\", \"description\": \"<html>\", \"parameters\": {\"type\": \"object\", \"properties\": null}}}"
if diff := cmp.Diff(want, got); diff != "" {
t.Fatalf("lfm2JSON mismatch (-want +got):\n%s", diff)
}
}

@@ -85,9 +85,9 @@ func rendererForName(name string) Renderer {
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
default:
return nil
}

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo

// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

@@ -37,9 +37,11 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
case "MXFP8":
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
case "FP8", "Q8", "INT8":
// 8-bit quantization with affine mode (default for quantized models)
return 64, 8, "affine"
case "":
return 0, 0, ""
default:
return 32, 8, "affine" // Default to affine
}

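A brief sketch of the behavior change (illustrative values):

// With the new case ordering, an empty quantization string no longer
// falls into the 8-bit affine branch:
//
//	QuantizationParams("FP8") // => 64, 8, "affine"
//	QuantizationParams("")    // => 0, 0, "" (treated as unquantized)
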
@@ -3,94 +3,65 @@
package mlxrunner

import (
"fmt"
"log/slog"

"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)

// CacheEntry stores a single sequence
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
Tokens []int32
Caches []cache.Cache
}

func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}

current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}

index += 1
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}

if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}

var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}

prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}

slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
}

slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
}

slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
}

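A sketch of the simplified cache behavior (hypothetical token values):

// The rewritten cache keeps a single entry and matches on the longest
// common prefix of the cached token sequence:
//
//	r.InsertCache([]int32{1, 2, 3, 4}, caches)
//	r.FindNearestCache([]int32{1, 2, 3, 9})
//	// prefix=3: caches trimmed by one token, returns (caches, [9])
//	r.FindNearestCache([]int32{7, 8})
//	// prefix=0: caches freed, entry cleared, returns (nil, [7, 8])
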
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
current := &CacheEntry{Entries: s.CacheEntries}
|
||||
for _, token := range tokens {
|
||||
if _, ok := current.Entries[token]; !ok {
|
||||
current.Entries[token] = &CacheEntry{
|
||||
Entries: make(map[int32]*CacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
current = current.Entries[token]
|
||||
}
|
||||
|
||||
if len(current.Caches) > 0 {
|
||||
current.Count += 1
|
||||
} else {
|
||||
current.Caches = caches
|
||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
r.cache = &CacheEntry{
|
||||
Tokens: tokens,
|
||||
Caches: caches,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheEntry) LogCache() {
|
||||
var totalBytes int
|
||||
for _, kv := range c.Caches {
|
||||
k, v := kv.State()
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
}
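The runner now keeps exactly one cached sequence: InsertCache overwrites r.cache rather than growing the old token trie. A rough sketch of that single-slot behavior, with toy types standing in for Runner and CacheEntry (illustrative only):

package main

import "fmt"

type entry struct {
	tokens []int32
}

type runner struct {
	cache *entry
}

// insert mirrors the new InsertCache: the single slot is replaced wholesale.
func (r *runner) insert(tokens []int32) {
	r.cache = &entry{tokens: tokens}
}

func main() {
	r := &runner{}
	r.insert([]int32{1, 2, 3})       // first request caches its sequence
	r.insert([]int32{1, 2, 3, 4, 5}) // a later request overwrites it
	fmt.Println(len(r.cache.tokens)) // 5: only the most recent sequence survives
}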

x/mlxrunner/cache/cache.go
@@ -3,8 +3,7 @@
package cache

import (
	"log/slog"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

@@ -13,6 +12,7 @@ type Cache interface {
	State() (keys, values *mlx.Array)
	Trim(int) int
	Clone() Cache
	Free()
	Offset() int
	Len() int
}
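The Cache contract shown here couples state accessors with explicit lifecycle methods (Clone, Free), so each implementation owns the pinning of its own copies. For reference, a toy mirror of the contract with a placeholder array type, since mlx.Array is cgo-backed (none of this is repo code):

package main

import "fmt"

type array struct{ n int } // placeholder for *mlx.Array

type kvInterface interface {
	State() (keys, values *array)
	Trim(int) int
	Clone() kvInterface
	Free()
	Offset() int
	Len() int
}

type toyKVCache struct {
	keys, values *array
	offset       int
}

func (c *toyKVCache) State() (*array, *array) { return c.keys, c.values }
func (c *toyKVCache) Trim(n int) int          { c.offset -= n; return n } // toy semantics
func (c *toyKVCache) Clone() kvInterface {
	return &toyKVCache{keys: &array{c.keys.n}, values: &array{c.values.n}, offset: c.offset}
}
func (c *toyKVCache) Free()       { c.keys, c.values = nil, nil }
func (c *toyKVCache) Offset() int { return c.offset }
func (c *toyKVCache) Len() int    { return c.offset }

func main() {
	var c kvInterface = &toyKVCache{keys: &array{1}, values: &array{1}, offset: 8}
	clone := c.Clone()
	c.Free()
	fmt.Println(clone.Offset()) // 8: the clone survives freeing the original
}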

@@ -47,6 +47,7 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
		c.values.Set(c.values.Concatenate(2, newValues))
	} else {
		c.keys, c.values = newKeys, newValues
		mlx.Pin(c.keys, c.values)
	}
}

@@ -73,12 +74,19 @@ func (c *KVCache) Trim(n int) int {
}

func (c *KVCache) Clone() Cache {
	return &KVCache{
	clone := &KVCache{
		keys:   c.keys.Clone(),
		values: c.values.Clone(),
		offset: c.offset,
		step:   c.step,
	}
	mlx.Pin(clone.keys, clone.values)
	return clone
}

func (c *KVCache) Free() {
	mlx.Unpin(c.keys, c.values)
	c.keys, c.values = nil, nil
}

func (c *KVCache) Offset() int { return c.offset }

@@ -104,9 +112,10 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
}

func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
	slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	if c.keys == nil {
		c.keys, c.values = keys, values
		c.keys, c.values = keys.Clone(), values.Clone()
		mlx.Pin(c.keys, c.values)
	} else {
		if c.idx < c.keys.Dim(2) {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))

@@ -130,7 +139,7 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
}

func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

@@ -145,6 +154,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
		c.values.Set(c.values.Concatenate(2, newValues))
	} else {
		c.keys, c.values = newKeys, newValues
		mlx.Pin(c.keys, c.values)
	}
	c.idx = prev
}

@@ -119,16 +119,13 @@ func NewClient(modelName string) (*Client, error) {
	stdout, _ := cmd.StdoutPipe()
	stderr, _ := cmd.StderrPipe()
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			slog.Info("mlx-runner", "msg", scanner.Text())
		}
		io.Copy(os.Stderr, stdout) //nolint:errcheck
	}()
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			slog.Warn("mlx-runner", "msg", line)
			fmt.Fprintln(os.Stderr, line)
			c.lastErrLock.Lock()
			c.lastErr = line
			c.lastErrLock.Unlock()

@@ -7,48 +7,29 @@ import "C"

import (
	"encoding/binary"
	"fmt"
	"log/slog"
	"reflect"
	"sort"
	"strings"
	"time"
	"unsafe"

	"github.com/ollama/ollama/logutil"
)

type tensorDesc struct {
	name    string
	inputs  []*Array
	numRefs int
}

func (d tensorDesc) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", d.name),
		slog.Int("inputs", len(d.inputs)),
		slog.Int("num_refs", d.numRefs),
	)
}

type Array struct {
	ctx  C.mlx_array
	desc tensorDesc
	ctx    C.mlx_array
	name   string
	pinned bool
}

var arrays []*Array

// constructor utilities

func New(name string, inputs ...*Array) *Array {
	t := &Array{
		desc: tensorDesc{
			name:   name,
			inputs: inputs,
		},
	}

	for _, input := range inputs {
		input.desc.numRefs++
	}
	logutil.Trace("New", "t", t)
func New(name string) *Array {
	t := &Array{name: name}
	arrays = append(arrays, t)
	return t
}
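This constructor change is the crux of the new memory model: instead of threading input references through every op so Free could walk a dependency graph, each Array is simply recorded in a package-level registry at creation and reclaimed later by Sweep. A reduced sketch of the pattern with a toy handle type (illustrative, not the mlx package itself):

package main

import "fmt"

type handle struct {
	name   string
	pinned bool
}

var registry []*handle

// newHandle mirrors New: record every allocation so a later sweep can
// find it, with no per-handle reference bookkeeping at all.
func newHandle(name string) *handle {
	h := &handle{name: name}
	registry = append(registry, h)
	return h
}

func main() {
	newHandle("MATMUL")
	newHandle("ADD")
	fmt.Println(len(registry)) // 2: both intermediates are tracked globally
}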

@@ -133,18 +114,51 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}

func (t *Array) Set(other *Array) {
	Free(t.desc.inputs...)
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}

func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	tt := New(t.name)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}

// lifecycle utilities

// Pin marks arrays as in-use so they are retained during Sweep.
func Pin(s ...*Array) {
	for _, t := range s {
		if t != nil {
			t.pinned = true
		}
	}
}

// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
func Unpin(s ...*Array) {
	for _, t := range s {
		if t != nil {
			t.pinned = false
		}
	}
}

// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
	n := 0
	for _, t := range arrays {
		if t.pinned && t.Valid() {
			arrays[n] = t
			n++
		} else if t.Valid() {
			C.mlx_array_free(t.ctx)
			t.ctx.ctx = nil
		}
	}
	arrays = arrays[:n]
}
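Pin, Unpin, and Sweep together form a simple mark-and-sweep over that registry: tensors needed across an evaluation step are marked, everything else is freed, and the survivors are compacted to the front of the slice. A self-contained sketch of the same compaction loop (toy types again, not the real cgo-backed arrays):

package main

import "fmt"

type handle struct {
	name   string
	pinned bool
}

var registry []*handle

func sweep() (freed int) {
	n := 0
	for _, h := range registry {
		if h.pinned {
			registry[n] = h // keep: compact survivors to the front
			n++
		} else {
			freed++ // the real Sweep calls C.mlx_array_free here
		}
	}
	registry = registry[:n]
	return freed
}

func main() {
	weights := &handle{name: "weights", pinned: true}
	scratch := &handle{name: "intermediate"}
	registry = append(registry, weights, scratch)

	fmt.Println(sweep(), len(registry)) // 1 1: only the pinned handle survives
}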

// misc. utilities

func (t *Array) Valid() bool {

@@ -159,7 +173,10 @@ func (t *Array) String() string {
}

func (t *Array) LogValue() slog.Value {
	attrs := []slog.Attr{slog.Any("", t.desc)}
	attrs := []slog.Attr{
		slog.String("name", t.name),
		slog.Bool("pinned", t.pinned),
	}
	if t.Valid() {
		attrs = append(attrs,
			slog.Any("dtype", t.DType()),

@@ -238,37 +255,15 @@ func (t Array) Save(name string) error {
	return nil
}

func Free(s ...*Array) (n int) {
	now := time.Now()
	defer func() {
		if n > 0 {
			logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
		}
	}()
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
	sort.Slice(arrays, func(i, j int) bool {
		return arrays[i].NumBytes() > arrays[j].NumBytes()
	})

	free := make([]*Array, 0, 8192)
	fn := func(t *Array) {
		if t.Valid() {
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				free = append(free, t.desc.inputs...)
				logutil.Trace("Free", "t", t)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				t.ctx.ctx = nil
			}
		}
	for _, t := range arrays {
		nb := t.NumBytes()
		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
	}

	for _, t := range s {
		fn(t)
	}

	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}

	return n
	logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}

@@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))

	out := New("FAST_SDPA", query, key, value, mask, sinks)
	out := New("FAST_SDPA")
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}

@@ -31,7 +31,7 @@ type LayerNorm struct {
}

func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	out := New("FAST_LAYERNORM")
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

@@ -41,7 +41,7 @@ type RMSNorm struct {
}

func (r RMSNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	out := New("FAST_RMSNORM")
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

@@ -55,7 +55,7 @@ type RoPE struct {

func (r RoPE) Forward(t *Array, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", t, freqs)
	out := New("FAST_ROPE")
	C.mlx_fast_rope(
		&out.ctx,
		t.ctx,

@@ -37,7 +37,9 @@ func Load(path string) iter.Seq2[string, *Array] {
	}

	name := C.GoString(key)
	if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
	arr := New(name)
	arr.ctx = value
	if !yield(name, arr) {
		break
	}
}

@@ -10,43 +10,43 @@ import (
)

func (t *Array) Abs() *Array {
	out := New("ABS", t)
	out := New("ABS")
	C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Add(other *Array) *Array {
	out := New("ADD", t, other)
	out := New("ADD")
	C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
	out := New("ADDMM", t, a, b)
	out := New("ADDMM")
	C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
	return out
}

func (t *Array) Argmax(axis int, keepDims bool) *Array {
	out := New("ARGMAX", t)
	out := New("ARGMAX")
	C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
	out := New("ARGPARTITION", t)
	out := New("ARGPARTITION")
	C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) ArgsortAxis(axis int) *Array {
	out := New("ARGSORT_AXIS", t)
	out := New("ARGSORT_AXIS")
	C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) AsType(dtype DType) *Array {
	out := New("AS_TYPE", t)
	out := New("AS_TYPE")
	C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}

@@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
	cStrides[i] = C.int64_t(s)
}

	out := New("AS_STRIDED", t)
	out := New("AS_STRIDED")
	C.mlx_as_strided(
		&out.ctx, t.ctx,
		unsafe.SliceData(cShape), C.size_t(len(shape)),

@@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
	C.mlx_vector_array_append_value(vector, other.ctx)
}

	out := New("CONCATENATE", s...)
	out := New("CONCATENATE")
	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Divide(other *Array) *Array {
	out := New("DIVIDE", t, other)
	out := New("DIVIDE")
	C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) ExpandDims(axis int) *Array {
	out := New("EXPAND_DIMS", t)
	out := New("EXPAND_DIMS")
	C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Flatten(startAxis, endAxis int) *Array {
	out := New("FLATTEN", t)
	out := New("FLATTEN")
	C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
	return out
}

func (t *Array) FloorDivide(other *Array) *Array {
	out := New("FLOOR_DIVIDE", t, other)
	out := New("FLOOR_DIVIDE")
	C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

@@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
	if rhs == nil {
		rhs = New("")
	}
	out := New("GATHER_MM", t, other, lhs, rhs)
	out := New("GATHER_MM")
	C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
	return out
}

func (t *Array) Logsumexp(keepDims bool) *Array {
	out := New("LOGSUMEXP", t)
	out := New("LOGSUMEXP")
	C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
	return out
}

func (t *Array) Matmul(other *Array) *Array {
	out := New("MATMUL", t, other)
	out := New("MATMUL")
	C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Multiply(other *Array) *Array {
	out := New("MULTIPLY", t, other)
	out := New("MULTIPLY")
	C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Negative() *Array {
	out := New("NEGATIVE", t)
	out := New("NEGATIVE")
	C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Power(exponent *Array) *Array {
	out := New("POWER", t, exponent)
	out := New("POWER")
	C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
	out := New("PUT_ALONG_AXIS", t, indices, values)
	out := New("PUT_ALONG_AXIS")
	C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

@@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array {
	cAxes[i] = C.int(axes[i])
}

	out := New("RESHAPE", t)
	out := New("RESHAPE")
	C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

func (t *Array) Sigmoid() *Array {
	out := New("SIGMOID", t)
	out := New("SIGMOID")
	C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Sqrt() *Array {
	out := New("SQRT", t)
	out := New("SQRT")
	C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Squeeze(axis int) *Array {
	out := New("SQUEEZE", t)
	out := New("SQUEEZE")
	C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

@@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array {
	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)

	out := New("STACK_AXIS", append(others, t)...)
	out := New("STACK_AXIS")
	C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Subtract(other *Array) *Array {
	out := New("SUBTRACT", t, other)
	out := New("SUBTRACT")
	C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) SumAxis(axis int, keepDims bool) *Array {
	out := New("SUM_AXIS", t)
	out := New("SUM_AXIS")
	C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

func (t *Array) TakeAxis(indices *Array, axis int) *Array {
	out := New("TAKE_AXIS", t, indices)
	out := New("TAKE_AXIS")
	C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
	out := New("TAKE_ALONG_AXIS", t, indices)
	out := New("TAKE_ALONG_AXIS")
	C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Tanh() *Array {
	out := New("TANH", t)
	out := New("TANH")
	C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

@@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array {
	cAxes[i] = C.int(axis)
}

	out := New("TRANSPOSE", t)
	out := New("TRANSPOSE")
	C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

@@ -41,14 +41,12 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	optDtype := C.mlx_optional_dtype{has_value: false}

	inputs := []*Array{w, scales}
	var b C.mlx_array
	if biases != nil {
		b = biases.ctx
		inputs = append(inputs, biases)
	}

	out := New("DEQUANTIZE", inputs...)
	out := New("DEQUANTIZE")
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
	return out
}

@@ -59,14 +57,12 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

	inputs := []*Array{x, w, scales}
	var b C.mlx_array
	if biases != nil {
		b = biases.ctx
		inputs = append(inputs, biases)
	}

	out := New("QUANTIZED_MATMUL", inputs...)
	out := New("QUANTIZED_MATMUL")
	C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
	return out
}

@@ -77,22 +73,18 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

	inputs := []*Array{x, w, scales}
	var b, lhs, rhs C.mlx_array
	if biases != nil {
		b = biases.ctx
		inputs = append(inputs, biases)
	}
	if lhsIndices != nil {
		lhs = lhsIndices.ctx
		inputs = append(inputs, lhsIndices)
	}
	if rhsIndices != nil {
		rhs = rhsIndices.ctx
		inputs = append(inputs, rhsIndices)
	}

	out := New("GATHER_QMM", inputs...)
	out := New("GATHER_QMM")
	C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
	return out
}

@@ -104,7 +96,7 @@ func Tile(a *Array, reps []int32) *Array {
	for i, r := range reps {
		cReps[i] = C.int(r)
	}
	out := New("TILE", a)
	out := New("TILE")
	C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
	return out
}

@@ -116,7 +108,7 @@ func Tri(n, m int32, k int) *Array {
}

func Where(condition, a, b *Array) *Array {
	out := New("WHERE", condition, a, b)
	out := New("WHERE")
	C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

@@ -131,7 +123,7 @@ func Stack(arrays []*Array, axis int) *Array {
	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)

	out := New("STACK", arrays...)
	out := New("STACK")
	C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

@@ -153,13 +145,13 @@ func Take(a *Array, indices *Array, axis int) *Array {
}

func RSqrt(a *Array) *Array {
	out := New("RSQRT", a)
	out := New("RSQRT")
	C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

func Mean(a *Array, axis int, keepDims bool) *Array {
	out := New("MEAN_AXIS", a)
	out := New("MEAN_AXIS")
	C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

@@ -235,7 +227,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array {
	cStop[i] = C.int(stop[i])
	cStrides[i] = 1
}
	out := New("SLICE", a)
	out := New("SLICE")
	C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
	return out
}

@@ -257,7 +249,7 @@ func SiLU(a *Array) *Array {

func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", x, freqs)
	out := New("FAST_ROPE")
	C.mlx_fast_rope(
		&out.ctx,
		x.ctx,

@@ -289,13 +281,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))

	out := New("FAST_SDPA", q, k, v, mask, sinks)
	out := New("FAST_SDPA")
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}

func RMSNormFn(x, weight *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	out := New("FAST_RMSNORM")
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

@@ -322,7 +314,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array {

func AddScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("ADD_SCALAR", a)
	out := New("ADD_SCALAR")
	C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out

@@ -330,7 +322,7 @@ func AddScalar(a *Array, s float32) *Array {

func MulScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("MUL_SCALAR", a)
	out := New("MUL_SCALAR")
	C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out

@@ -338,7 +330,7 @@ func MulScalar(a *Array, s float32) *Array {

func DivScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	out := New("DIV_SCALAR", a)
	out := New("DIV_SCALAR")
	C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
	C.mlx_array_free(scalar)
	return out

@@ -7,7 +7,7 @@ import "C"

func (t *Array) Categorical(axis int) *Array {
	key := New("")
	out := New("", t, key)
	out := New("")
	C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
	return out
}

@@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {

func (t *Array) Slice(slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE", t)
	out := New("SLICE")
	C.mlx_slice(
		&out.ctx, t.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),

@@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array {

func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE_UPDATE", t, other)
	out := New("SLICE_UPDATE")
	C.mlx_slice_update(
		&out.ctx, t.ctx, other.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),

@@ -78,8 +78,21 @@ func New(root *model.Root) (Model, error) {
	return fn(root)
}

// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {
	return m.LoadWeights
	return func(tensors map[string]*mlx.Array) error {
		if err := m.LoadWeights(tensors); err != nil {
			return err
		}

		collected := mlx.Collect(m)
		for _, arr := range collected {
			mlx.Pin(arr)
		}
		mlx.Sweep()
		mlx.Eval(collected...)

		return nil
	}
}
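The new Weights wrapper centralizes housekeeping that each model previously did for itself (which is why the per-model Collect/Eval calls are deleted in the hunks further down): load, then pin everything reachable from the model, then sweep loader intermediates, then evaluate. A toy of the higher-order wrapper shape, under the assumption that pin/sweep/eval behave as sketched earlier (placeholder types, not repo code):

package main

import "fmt"

type model struct{ loaded bool }

func (m *model) loadWeights(tensors map[string]int) error {
	m.loaded = true
	return nil
}

// weights mirrors the wrapper above: run the model's own loader, then do
// the pin/sweep/eval housekeeping in one shared place.
func weights(m *model) func(map[string]int) error {
	return func(tensors map[string]int) error {
		if err := m.loadWeights(tensors); err != nil {
			return err
		}
		// In the real code: mlx.Pin(collected...), mlx.Sweep(), mlx.Eval(collected...).
		// Pinning before sweeping is what keeps the weights alive.
		return nil
	}
}

func main() {
	m := &model{}
	load := weights(m)
	_ = load(map[string]int{"weight": 1})
	fmt.Println(m.loaded) // true
}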

@@ -17,8 +17,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
		return 32, 4, "affine"
	case "MXFP8":
		return 32, 8, "mxfp8"
	case "FP8", "Q8", "INT8", "":
	case "FP8", "Q8", "INT8":
		return 64, 8, "affine"
	case "":
		return 0, 0, ""
	default:
		return 32, 8, "affine"
	}
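The behavioral change here is easy to miss: an empty quantization string previously fell into the 8-bit affine case and now maps to zero values, i.e. "not quantized". A local restatement of the new switch for illustration (the real function lives in the runner code; the case returning 32, 4 is truncated out of this hunk, so it is omitted):

package main

import "fmt"

func quantizationParams(quantization string) (groupSize, bits int, mode string) {
	switch quantization {
	case "MXFP8":
		return 32, 8, "mxfp8"
	case "FP8", "Q8", "INT8":
		return 64, 8, "affine"
	case "":
		return 0, 0, "" // new: no quantization requested
	default:
		return 32, 8, "affine"
	}
}

func main() {
	g, b, m := quantizationParams("")
	fmt.Println(g, b, m) // 0 0  (previously this returned 64 8 affine)
}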

@@ -4,10 +4,12 @@ package mlxrunner

import (
	"bytes"
	"context"
	"errors"
	"log/slog"
	"time"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

@@ -45,8 +47,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
	slog.Info("Prompt processing progress", "processed", processed, "total", total)
	for total-processed > 1 {
		n := min(2<<10, total-processed-1)
		temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		defer mlx.Free(temp)
		r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		mlx.Sweep()
		mlx.Eval(func() []*mlx.Array {
			s := make([]*mlx.Array, 2*len(caches))
			for i, c := range caches {

@@ -65,11 +67,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		logprobs := logits.Subtract(logits.Logsumexp(true))
		return request.Sample(logprobs), logprobs
		sample := request.Sample(logprobs)

		mlx.Pin(sample, logprobs)
		mlx.Sweep()
		mlx.AsyncEval(sample, logprobs)

		return sample, logprobs
	}

	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
	mlx.AsyncEval(sample, logprobs)

	var b bytes.Buffer

@@ -78,7 +85,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
	outputs := make([]int32, 0, request.Options.MaxTokens)
	for i := range request.Options.MaxTokens {
		nextSample, nextLogprobs := step(sample)
		mlx.AsyncEval(nextSample, nextLogprobs)

		if i == 0 {
			slog.Info("Prompt processing progress", "processed", total, "total", total)

@@ -91,6 +97,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
		outputs = append(outputs, output)

		if r.Tokenizer.IsEOS(output) {
			mlx.Unpin(nextSample, nextLogprobs)
			final.Token = int(output)
			final.DoneReason = 0
			final.CompletionTokens = i

@@ -102,7 +109,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
			Token: int(output),
		}

		mlx.Free(sample, logprobs)
		mlx.Unpin(sample, logprobs)
		if i%256 == 0 {
			mlx.ClearCache()
		}

@@ -110,10 +117,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
		sample, logprobs = nextSample, nextLogprobs
	}

	mlx.Free(sample, logprobs)
	mlx.Unpin(sample, logprobs)
	final.CompletionTokensDuration = time.Since(now)
	request.Responses <- final
	r.InsertCache(append(inputs, outputs...), caches)
	mlx.Sweep()

	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
		mlx.LogArrays()
		if r.cache != nil {
			r.cache.LogCache()
		}
	}

	return nil
}
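The decode loop now pipelines two steps: while step i's sample is being read back, step i+1 has already been queued with AsyncEval, and pin/unpin replaces the old explicit Free calls so Sweep can reclaim intermediates between steps. A toy of that handoff pattern (pin and unpin stand in for mlx.Pin and mlx.Unpin; no MLX involved):

package main

import "fmt"

type tensor struct {
	name   string
	pinned bool
}

func pin(ts ...*tensor) {
	for _, t := range ts {
		t.pinned = true
	}
}

func unpin(ts ...*tensor) {
	for _, t := range ts {
		t.pinned = false
	}
}

func main() {
	sample := &tensor{name: "step0"}
	pin(sample) // step() pins its outputs before sweeping

	for i := 1; i <= 3; i++ {
		next := &tensor{name: fmt.Sprintf("step%d", i)}
		pin(next) // the next step's outputs must survive the sweep
		// ... read sample's value here, then release it ...
		unpin(sample) // the previous outputs become sweepable
		sample = next
	}
	unpin(sample)
	fmt.Println(sample.name) // step3
}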

@@ -58,10 +58,10 @@ type Response struct {
}

type Runner struct {
	Model        base.Model
	Tokenizer    *tokenizer.Tokenizer
	Requests     chan Request
	CacheEntries map[int32]*CacheEntry
	Model     base.Model
	Tokenizer *tokenizer.Tokenizer
	Requests  chan Request
	cache     *CacheEntry
}

func (r *Runner) Load(modelName string) error {

@@ -40,8 +40,7 @@ func Execute(args []string) error {
	flagSet.Parse(args)

	runner := Runner{
		Requests:     make(chan Request),
		CacheEntries: make(map[int32]*CacheEntry),
		Requests: make(chan Request),
	}

	if err := runner.Load(modelName); err != nil {

@@ -401,9 +401,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	if m.NormScaled == nil {
		return fmt.Errorf("missing precomputed final norm weight")
	}
	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}

@@ -702,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
		}
	}

	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}

@@ -235,9 +235,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
		m.Layers[i] = layer
	}

	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}

@@ -252,9 +252,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
		m.Layers[i] = layer
	}

	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}