Compare commits

..

15 Commits

Author SHA1 Message Date
Jeffrey Morgan
97b2a92299 lint 2026-02-22 22:11:47 -08:00
Jeffrey Morgan
0c82c9069a minimize changes 2026-02-22 21:54:39 -08:00
Jeffrey Morgan
a52a739dda lint 2026-02-22 21:33:08 -08:00
Jeffrey Morgan
906b046928 cleanup 2026-02-22 19:42:07 -08:00
Jeffrey Morgan
7469d659e7 models: improved support for LiquidAI LFM2 and LFM2.5 models 2026-02-22 19:19:10 -08:00
Jeffrey Morgan
a8c1ca226c vision support 2026-02-22 15:11:41 -08:00
Jeffrey Morgan
317cfb5610 wip 2026-02-22 15:11:41 -08:00
jmorganca
4e9d388f65 fix repeat messages causing crash 2026-02-22 15:11:40 -08:00
Jeffrey Morgan
0ade9205cc models: add nemotronh architecture support (#14356) 2026-02-22 15:09:14 -08:00
Parth Sareen
06edabdde1 cmd/config: install web search plugin to user-level extensions dir (#14362) 2026-02-22 02:17:03 -08:00
Jeffrey Morgan
8b4e5a82a8 mlx: remove noisy error output from dynamic library loading (#14346)
The recent change in #14322 added tryLoadByName() which attempts to
load libmlxc.dylib via rpath before searching directories. This is an
optimization for Homebrew installations where rpath is correctly set.

However, when rpath isn't set (which is the common case for app bundle
installations), dlopen fails and the CHECK macro prints an error to
stderr:

  ERROR - dynamic.c:21 - CHECK failed: handle->ctx != NULL

This error is misleading because it's an expected failure path - the
code correctly falls back to searching the executable directory and
loads the library successfully. The error message causes user confusion
and makes it appear that something is broken.

Replace the CHECK macro with a simple return code so the C code fails
silently. The Go code already handles error logging appropriately:
tryLoadByName() fails silently (intentional fallback), while
tryLoadFromDir() logs via slog.Error() when explicit path loading fails.
2026-02-20 23:46:07 -08:00
Parth Sareen
3445223311 cmd: openclaw onboarding (#14344) 2026-02-20 19:08:38 -08:00
Jeffrey Morgan
fa6c0127e6 app: expose server's default context length to UI (#14037)
Parse the default_num_ctx from the server's "vram-based default context"
log line and expose it through the inference compute API. This eliminates
duplicate VRAM tier calculation logic in the frontend.

- Add InferenceInfo struct with Computes and DefaultContextLength
- Rename GetInferenceComputer to GetInferenceInfo
- Handle missing default context line gracefully (older servers)
- Add DefaultContextLength to InferenceComputeResponse
- Update Settings UI to use server's default, disable slider while loading
- Add disabled prop to Slider component (grays out + hides handle)
- Migrate existing users with context_length=4096 to 0 (auto mode)
2026-02-20 18:56:30 -08:00
Patrick Devine
97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package which includes:
  * New BPE and SentencePiece tokenizers
  * Removing the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00
natl-set
458dd1b9d9 mlx: try loading library via rpath before searching directories (#14322)
The existing code manually searches directories for libmlxc.* and passes
full paths to dlopen, bypassing the binary's rpath. This means MLX
libraries installed via package managers (e.g., Homebrew) aren't found
even when rpath is correctly set at link time.

This change adds a fallback that tries loading via rpath first (using
just the library name), before falling back to the existing directory
search. This follows standard Unix/macOS conventions and works with any
installation that sets rpath.

Fixes library loading on macOS with Homebrew-installed mlx-c without
requiring OLLAMA_LIBRARY_PATH environment variable.

Co-authored-by: Natl <nat@MacBook-Pro.local>
2026-02-19 10:55:02 -08:00
81 changed files with 9924 additions and 1595 deletions

View File

@@ -41,6 +41,11 @@ type InferenceCompute struct {
VRAM string
}
type InferenceInfo struct {
Computes []InferenceCompute
DefaultContextLength int
}
func New(s *store.Store, devMode bool) *Server {
p := resolvePath("ollama")
return &Server{store: s, bin: p, dev: devMode}
@@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
// Attempt to retrieve inference compute information from the server
// log. Set ctx to timeout to control how long to wait for the logs to appear
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
inference := []InferenceCompute{}
marker := regexp.MustCompile(`inference compute.*library=`)
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
info := &InferenceInfo{}
computeMarker := regexp.MustCompile(`inference compute.*library=`)
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
q := `inference compute.*%s=["]([^"]*)["]`
nq := `inference compute.*%s=(\S+)\s`
type regex struct {
@@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
match := marker.FindStringSubmatch(line)
if len(match) > 0 {
// Check for inference compute lines
if computeMarker.MatchString(line) {
ic := InferenceCompute{
Library: get("library", line),
Variant: get("variant", line),
@@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
}
slog.Info("Matched", "inference compute", ic)
inference = append(inference, ic)
} else {
// Break out on first non matching line after we start matching
if len(inference) > 0 {
return inference, nil
info.Computes = append(info.Computes, ic)
continue
}
// Check for default context length line
if defaultCtxMarker.MatchString(line) {
match := defaultCtxRegex.FindStringSubmatch(line)
if len(match) > 1 {
numCtx, err := strconv.Atoi(match[1])
if err == nil {
info.DefaultContextLength = numCtx
slog.Info("Matched default context length", "default_num_ctx", numCtx)
}
}
return info, nil
}
// If we've found compute info but hit a non-matching line, return what we have
// This handles older server versions that don't log the default context line
if len(info.Computes) > 0 {
return info, nil
}
}
time.Sleep(100 * time.Millisecond)

View File

@@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
}
}
func TestGetInferenceComputer(t *testing.T) {
func TestGetInferenceInfo(t *testing.T) {
tests := []struct {
name string
log string
exp []InferenceCompute
name string
log string
expComputes []InferenceCompute
expDefaultCtxLen int
}{
{
name: "metal",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 262144,
},
{
name: "cpu",
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cpu",
Driver: "0.0",
VRAM: "31.3 GiB",
}},
expDefaultCtxLen: 32768,
},
{
name: "cuda1",
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
releasing cuda driver library
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cuda",
Variant: "v12",
Compute: "6.1",
@@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
Name: "NVIDIA GeForce GT 1030",
VRAM: "3.9 GiB",
}},
expDefaultCtxLen: 4096,
},
{
name: "frank",
@@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
releasing cuda driver library
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{
expComputes: []InferenceCompute{
{
Library: "cuda",
Variant: "v12",
@@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
VRAM: "16.0 GiB",
},
},
expDefaultCtxLen: 32768,
},
{
name: "missing_default_context",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 0, // No default context line, should return 0
},
}
for _, tt := range tests {
@@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
}
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
ics, err := GetInferenceComputer(ctx)
info, err := GetInferenceInfo(ctx)
if err != nil {
t.Fatalf(" failed to get inference compute: %v", err)
t.Fatalf("failed to get inference info: %v", err)
}
if !reflect.DeepEqual(ics, tt.exp) {
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
}
if info.DefaultContextLength != tt.expDefaultCtxLen {
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
}
})
}
}
func TestGetInferenceComputerTimeout(t *testing.T) {
func TestGetInferenceInfoTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
tmpDir := t.TempDir()
@@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}
_, err = GetInferenceComputer(ctx)
_, err = GetInferenceInfo(ctx)
if err == nil {
t.Fatal("expected timeout")
}

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 13
const currentSchemaVersion = 14
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -73,7 +73,7 @@ func (db *database) init() error {
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -251,6 +251,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
case 13:
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
if err := db.migrateV13ToV14(); err != nil {
return fmt.Errorf("migrate v13 to v14: %w", err)
}
version = 14
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -474,6 +480,22 @@ func (db *database) migrateV12ToV13() error {
return nil
}
// migrateV13ToV14 changes the default context_length from 4096 to 0.
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
func (db *database) migrateV13ToV14() error {
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
if err != nil {
return fmt.Errorf("update context_length default: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`

View File

@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
})
}
func TestMigrationV13ToV14ContextLength(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
if err != nil {
t.Fatalf("failed to seed v13 settings row: %v", err)
}
if err := db.migrate(); err != nil {
t.Fatalf("migration from v13 to v14 failed: %v", err)
}
var contextLength int
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
t.Fatalf("failed to read context_length: %v", err)
}
if contextLength != 0 {
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
}
version, err := db.getSchemaVersion()
if err != nil {
t.Fatalf("failed to get schema version: %v", err)
}
if version != currentSchemaVersion {
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
}
}
func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir()

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,

View File

@@ -289,10 +289,12 @@ export class InferenceCompute {
}
export class InferenceComputeResponse {
inferenceComputes: InferenceCompute[];
defaultContextLength: number;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
this.defaultContextLength = source["defaultContextLength"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

View File

@@ -4,7 +4,6 @@ import {
ChatEvent,
DownloadEvent,
ErrorEvent,
InferenceCompute,
InferenceComputeResponse,
ModelCapabilitiesResponse,
Model,
@@ -407,7 +406,7 @@ export async function* pullModel(
}
}
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
if (!response.ok) {
throw new Error(
@@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
}
const data = await response.json();
const inferenceComputeResponse = new InferenceComputeResponse(data);
return inferenceComputeResponse.inferenceComputes || [];
return new InferenceComputeResponse(data);
}
export async function fetchHealth(): Promise<boolean> {

View File

@@ -26,6 +26,7 @@ import {
type CloudStatusResponse,
updateCloudSetting,
updateSettings,
getInferenceCompute,
} from "@/api";
function AnimatedDots() {
@@ -77,6 +78,13 @@ export default function Settings() {
const settings = settingsData?.settings || null;
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
});
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
const updateSettingsMutation = useMutation({
mutationFn: updateSettings,
onSuccess: () => {
@@ -204,7 +212,7 @@ export default function Settings() {
Models: "",
Agent: false,
Tools: false,
ContextLength: 4096,
ContextLength: 0,
});
updateSettingsMutation.mutate(defaultSettings);
}
@@ -507,13 +515,11 @@ export default function Settings() {
</Description>
<div className="mt-3">
<Slider
value={(() => {
// Otherwise use the settings value
return settings.ContextLength || 4096;
})()}
value={settings.ContextLength || defaultContextLength || 0}
onChange={(value) => {
handleChange("ContextLength", value);
}}
disabled={!defaultContextLength}
options={[
{ value: 4096, label: "4k" },
{ value: 8192, label: "8k" },

View File

@@ -6,10 +6,11 @@ export interface SliderProps {
value?: number;
onChange?: (value: number) => void;
className?: string;
disabled?: boolean;
}
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
({ label, options, value = 0, onChange }, ref) => {
({ label, options, value = 0, onChange, disabled = false }, ref) => {
const [selectedValue, setSelectedValue] = React.useState(value);
const [isDragging, setIsDragging] = React.useState(false);
const containerRef = React.useRef<HTMLDivElement>(null);
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}, [value]);
const handleClick = (optionValue: number) => {
if (disabled) return;
setSelectedValue(optionValue);
onChange?.(optionValue);
};
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
};
const handleMouseDown = (e: React.MouseEvent) => {
if (disabled) return;
setIsDragging(true);
e.preventDefault();
};
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}
return (
<div className="space-y-2" ref={ref}>
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
{label && <label className="text-sm font-medium">{label}</label>}
<div className="relative">
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
<button
onClick={() => handleClick(option.value)}
onMouseDown={handleMouseDown}
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
disabled={disabled}
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
>
<div className="relative w-5 h-5 flex items-center justify-center">
{selectedValue === option.value && (
{selectedValue === option.value && !disabled && (
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
)}
</div>

View File

@@ -28,12 +28,14 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
currentChatId && currentChatId !== "new" ? currentChatId : "",
);
const { data: inferenceComputes = [] } = useQuery({
queryKey: ["inference-compute"],
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
enabled: !settings.selectedModel, // Only fetch if no model is selected
});
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
const totalVRAM = useMemo(
() => getTotalVRAM(inferenceComputes),
[inferenceComputes],

View File

@@ -45,7 +45,8 @@ type InferenceCompute struct {
}
type InferenceComputeResponse struct {
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
DefaultContextLength int `json:"defaultContextLength"`
}
type ModelCapabilitiesResponse struct {

View File

@@ -1420,11 +1420,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
settings.Models = envconfig.Models()
}
// set default context length if not set
if settings.ContextLength == 0 {
settings.ContextLength = 4096
}
// Include current runtime settings
settings.Agent = s.Agent
settings.Tools = s.Tools
@@ -1500,14 +1495,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
info, err := server.GetInferenceInfo(ctx)
if err != nil {
s.log().Error("failed to get inference compute", "error", err)
return fmt.Errorf("failed to get inference compute: %w", err)
s.log().Error("failed to get inference info", "error", err)
return fmt.Errorf("failed to get inference info: %w", err)
}
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
for i, ic := range serverInferenceComputes {
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
for i, ic := range info.Computes {
inferenceComputes[i] = responses.InferenceCompute{
Library: ic.Library,
Variant: ic.Variant,
@@ -1519,7 +1514,8 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
}
response := responses.InferenceComputeResponse{
InferenceComputes: inferenceComputes,
InferenceComputes: inferenceComputes,
DefaultContextLength: info.DefaultContextLength,
}
w.Header().Set("Content-Type", "application/json")

View File

@@ -1956,6 +1956,10 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
launchIntegration := func(name string) bool {
if err := config.EnsureInstalled(name); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return true
}
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {

View File

@@ -15,8 +15,9 @@ import (
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Onboarded bool `json:"onboarded,omitempty"`
}
type config struct {
@@ -139,34 +140,54 @@ func SaveIntegration(appName string, models []string) error {
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
var onboarded bool
if existing != nil {
aliases = existing.Aliases
onboarded = existing.Onboarded
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
Models: models,
Aliases: aliases,
Onboarded: onboarded,
}
return save(cfg)
}
// integrationOnboarded marks an integration as onboarded in ollama's config.
func integrationOnboarded(appName string) error {
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
existing.Onboarded = true
cfg.Integrations[key] = existing
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return ""
}
return ic.Models[0]
return integrationConfig.Models[0]
}
// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return nil
}
return ic.Models
return integrationConfig.Models
}
// LastModel returns the last model that was run, or empty string if none.
@@ -234,12 +255,12 @@ func loadIntegration(appName string) (*integration, error) {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
return integrationConfig, nil
}
func saveAliases(appName string, aliases map[string]string) error {
@@ -272,8 +293,8 @@ func listIntegrations() ([]integration, error) {
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
for _, integrationConfig := range cfg.Integrations {
result = append(result, *integrationConfig)
}
return result, nil

View File

@@ -228,6 +228,31 @@ func IsIntegrationInstalled(name string) bool {
}
}
// AutoInstallable returns true if the integration can be automatically
// installed when not found (e.g. via npm).
func AutoInstallable(name string) bool {
switch strings.ToLower(name) {
case "openclaw", "clawdbot", "moltbot":
return true
default:
return false
}
}
// EnsureInstalled checks if an auto-installable integration is present and
// offers to install it if missing. Returns nil for non-auto-installable
// integrations or when the binary is already on PATH.
func EnsureInstalled(name string) error {
if !AutoInstallable(name) {
return nil
}
if IsIntegrationInstalled(name) {
return nil
}
_, err := ensureOpenclawInstalled()
return err
}
// IsEditorIntegration returns true if the named integration uses multi-model
// selection (implements the Editor interface).
func IsEditorIntegration(name string) bool {
@@ -926,6 +951,10 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if err := EnsureInstalled(name); err != nil {
return err
}
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
modelFlag = ""
}

View File

@@ -1,81 +1,287 @@
package config
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
return selectModels(context.Background(), "openclaw", "")
})
if errors.Is(err, errCancelled) {
return nil
}
bin, err := ensureOpenclawInstalled()
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
firstLaunch := true
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
firstLaunch = !integrationConfig.Onboarded
}
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := confirmPrompt("I understand the risks. Continue?")
if err != nil {
return err
}
if !ok {
return nil
}
}
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
cmd := exec.Command(bin, "onboard",
"--non-interactive",
"--accept-risk",
"--auth-choice", "skip",
"--gateway-token", "ollama",
"--install-daemon",
"--skip-channels",
"--skip-skills",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
}
patchDeviceScopes()
// Onboarding overwrites openclaw.json, so re-apply the model config
// that Edit() wrote before Run() was called.
if err := c.Edit([]string{model}); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
}
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
if ensureWebSearchPlugin() {
registerWebSearchPlugin()
}
}
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
}
err = cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
// When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow.
if len(args) > 0 {
cmd := exec.Command(bin, args...)
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil
}
return err
token, port := c.gatewayInfo()
addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes from Edit() above (model, provider, etc.).
if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv()
if err := restart.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
}
if !waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
}
}
// If the gateway isn't running, start it as a background child process.
if !portOpen(addr) {
gw := exec.Command(bin, "gateway", "run", "--force")
gw.Env = openclawEnv()
if err := gw.Start(); err != nil {
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
}
defer func() {
if gw.Process != nil {
_ = gw.Process.Kill()
_ = gw.Wait()
}
}()
}
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
if !waitForPort(addr, 30*time.Second) {
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
}
printOpenclawReady(bin, token, port, firstLaunch)
tuiArgs := []string{"tui"}
if firstLaunch {
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
}
tui := exec.Command(bin, tuiArgs...)
tui.Env = openclawEnv()
tui.Stdin = os.Stdin
tui.Stdout = os.Stdout
tui.Stderr = os.Stderr
if err := tui.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort
home, err := os.UserHomeDir()
if err != nil {
return "", port
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
continue
}
gw, _ := config["gateway"].(map[string]any)
if p, ok := gw["port"].(float64); ok && p > 0 {
port = int(p)
}
auth, _ := gw["auth"].(map[string]any)
if t, _ := auth["token"].(string); t != "" {
token = t
}
return token, port
}
return "", port
}
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
u := fmt.Sprintf("http://localhost:%d", port)
if token != "" {
u += "/#token=" + url.QueryEscape(token)
}
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
}
}
// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
clear := map[string]bool{
"ANTHROPIC_API_KEY": true,
"ANTHROPIC_OAUTH_TOKEN": true,
"OPENAI_API_KEY": true,
"GEMINI_API_KEY": true,
"MISTRAL_API_KEY": true,
"GROQ_API_KEY": true,
"XAI_API_KEY": true,
"OPENROUTER_API_KEY": true,
}
var env []string
for _, e := range os.Environ() {
key, _, _ := strings.Cut(e, "=")
if !clear[key] {
env = append(env, e)
}
}
return env
}
// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func waitForPort(addr string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err == nil {
conn.Close()
return true
}
time.Sleep(250 * time.Millisecond)
}
return false
}
func windowsHint(err error) error {
if runtime.GOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\n"+
"OpenClaw runs best on WSL2.\n"+
"Quick setup: wsl --install\n"+
"Guide: https://docs.openclaw.ai/windows", err)
}
// onboarded checks if OpenClaw onboarding wizard was completed
@@ -107,6 +313,144 @@ func (c *Openclaw) onboarded() bool {
return lastRunAt != ""
}
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
home, err := os.UserHomeDir()
if err != nil {
return
}
deviceID := readLocalDeviceID(home)
if deviceID == "" {
return
}
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var devices map[string]map[string]any
if err := json.Unmarshal(data, &devices); err != nil {
return
}
dev, ok := devices[deviceID]
if !ok {
return
}
required := []string{
"operator.read",
"operator.admin",
"operator.approvals",
"operator.pairing",
}
changed := patchScopes(dev, "scopes", required)
if tokens, ok := dev["tokens"].(map[string]any); ok {
for _, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok {
if patchScopes(tokenMap, "scopes", required) {
changed = true
}
}
}
}
if !changed {
return
}
out, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
if err != nil {
return ""
}
var auth map[string]any
if err := json.Unmarshal(data, &auth); err != nil {
return ""
}
id, _ := auth["deviceId"].(string)
return id
}
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
existing, _ := obj[key].([]any)
have := make(map[string]bool, len(existing))
for _, s := range existing {
if str, ok := s.(string); ok {
have[str] = true
}
}
added := false
for _, s := range required {
if !have[s] {
existing = append(existing, s)
added = true
}
}
if added {
obj[key] = existing
}
return added
}
func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return "clawdbot", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
"Install Node.js first:\n" +
" https://nodejs.org/\n\n" +
"Then rerun:\n" +
" ollama launch\n" +
"and select OpenClaw")
}
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("openclaw installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install openclaw: %w", err)
}
if _, err := exec.LookPath("openclaw"); err != nil {
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
return "openclaw", nil
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
@@ -161,8 +505,7 @@ func (c *Openclaw) Edit(models []string) error {
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
ollama["api"] = "ollama"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
@@ -175,25 +518,13 @@ func (c *Openclaw) Edit(models []string) error {
}
}
client, _ := api.ClientFromEnvironment()
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
for _, m := range models {
entry, _ := openclawModelConfig(context.Background(), client, m)
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
if existing, ok := existingByID[m]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
@@ -230,7 +561,213 @@ func (c *Openclaw) Edit(models []string) error {
if err != nil {
return err
}
return writeWithBackup(configPath, data)
if err := writeWithBackup(configPath, data); err != nil {
return err
}
// Clear any per-session model overrides so the new primary takes effect
// immediately rather than being shadowed by a cached modelOverride.
clearSessionModelOverride(models[0])
return nil
}
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
home, err := os.UserHomeDir()
if err != nil {
return
}
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var sessions map[string]map[string]any
if json.Unmarshal(data, &sessions) != nil {
return
}
changed := false
for _, sess := range sessions {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
sess["model"] = primary
changed = true
}
}
if !changed {
return
}
out, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
const webSearchNpmPackage = "@ollama/openclaw-web-search"
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present. Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
return true // already installed
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
return true
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir()
if err != nil {
return
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
data, err := os.ReadFile(configPath)
if err != nil {
return
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
return
}
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
plugins = make(map[string]any)
}
entries, _ := plugins["entries"].(map[string]any)
if entries == nil {
entries = make(map[string]any)
}
if _, ok := entries["openclaw-web-search"]; ok {
return // already registered
}
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
config["plugins"] = plugins
// Disable the built-in web search since our plugin replaces it.
tools, _ := config["tools"].(map[string]any)
if tools == nil {
tools = make(map[string]any)
}
web, _ := tools["web"].(map[string]any)
if web == nil {
web = make(map[string]any)
}
web["search"] = map[string]any{"enabled": false}
tools["web"] = web
config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ")
if err != nil {
return
}
_ = os.WriteFile(configPath, out, 0o600)
}
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
entry := map[string]any{
"id": modelID,
"name": modelID,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
}
if client == nil {
return entry, false
}
showCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
defer cancel()
}
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
if err != nil {
return entry, false
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
entry["input"] = []any{"text", "image"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
entry["reasoning"] = true
}
// Cloud models: use hardcoded limits for context/output tokens.
// Capability detection above still applies (vision, thinking).
if resp.RemoteModel != "" {
if l, ok := lookupCloudModelLimit(modelID); ok {
entry["contextWindow"] = l.Context
entry["maxTokens"] = l.Output
}
return entry, true
}
// Extract context window from ModelInfo (local models only)
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
entry["contextWindow"] = int(ctxLen)
}
break
}
}
return entry, false
}
func (c *Openclaw) Models() []string {

View File

@@ -1,11 +1,21 @@
package config
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestOpenclawIntegration(t *testing.T) {
@@ -26,6 +36,124 @@ func TestOpenclawIntegration(t *testing.T) {
})
}
func TestOpenclawRunPassthroughArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
bin := filepath.Join(tmpDir, "openclaw")
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 1 {
t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
}
if lines[0] != "gateway --someflag" {
t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
}
}
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
t.Run("success persists onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil {
t.Fatalf("loadIntegration() error = %v", err)
}
if !integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to be persisted after successful run")
}
})
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
t.Fatal("expected run failure")
}
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to remain unset after failed run")
}
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
@@ -359,19 +487,16 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
// Verify base schema fields (always set regardless of API availability)
if entry["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", entry["id"])
}
if entry["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", entry["name"])
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
@@ -876,3 +1001,589 @@ func TestOpenclawOnboarded(t *testing.T) {
}
})
}
func TestOpenclawGatewayInfo(t *testing.T) {
c := &Openclaw{}
t.Run("returns defaults when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("reads token and port from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {
"port": 9999,
"auth": {"mode": "token", "token": "my-secret"}
}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "my-secret" {
t.Errorf("expected token %q, got %q", "my-secret", token)
}
if port != 9999 {
t.Errorf("expected port 9999, got %d", port)
}
})
t.Run("uses default port when not in config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {"auth": {"token": "tok"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "tok" {
t.Errorf("expected token %q, got %q", "tok", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "legacy-token" {
t.Errorf("expected token %q, got %q", "legacy-token", token)
}
if port != 12345 {
t.Errorf("expected port 12345, got %d", port)
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
t.Run("handles missing gateway section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
}
func TestPrintOpenclawReady(t *testing.T) {
t.Run("includes port in URL", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", 9999, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "localhost:9999") {
t.Errorf("expected port 9999 in output, got:\n%s", output)
}
if strings.Contains(output, "#token=") {
t.Error("should not include token fragment when token is empty")
}
})
t.Run("URL-escapes token", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
escaped := url.QueryEscape("my token&special=chars")
if !strings.Contains(output, "#token="+escaped) {
t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
}
})
t.Run("simple token is not mangled", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "#token=ollama") {
t.Errorf("expected #token=ollama in output, got:\n%s", output)
}
})
t.Run("includes web UI hint", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Open the Web UI") {
t.Errorf("expected web UI hint in output, got:\n%s", output)
}
})
t.Run("first launch shows quick start tips", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
for _, want := range []string{"/help", "channels", "skills", "gateway"} {
if !strings.Contains(output, want) {
t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
}
}
})
t.Run("subsequent launch shows single tip", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Tip:") {
t.Errorf("expected single tip line, got:\n%s", output)
}
if strings.Contains(output, "Quick start") {
t.Errorf("should not show quick start on subsequent launch")
}
})
}
func TestOpenclawModelConfig(t *testing.T) {
t.Run("nil client returns base config", func(t *testing.T) {
cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")
if cfg["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", cfg["id"])
}
if cfg["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", cfg["name"])
}
if cfg["cost"] == nil {
t.Error("cost should be set")
}
// Should not have capability fields without API
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set without API")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set without API")
}
})
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
})
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 1 {
t.Errorf("input = %v, want [text]", cfg["input"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set for non-thinking model")
}
})
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwq")
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for thinking model")
}
})
t.Run("extracts context window from model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
if cfg["contextWindow"] != 131072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("handles all capabilities together", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true")
}
if cfg["contextWindow"] != 32768 {
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
}
})
t.Run("returns base config when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")
if cfg["id"] != "missing-model" {
t.Errorf("id = %v, want missing-model", cfg["id"])
}
// Should still have input (default)
if cfg["input"] == nil {
t.Error("input should always be set")
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set when show fails")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set when show fails")
}
})
t.Run("times out slow show and returns base config", func(t *testing.T) {
oldTimeout := openclawModelShowTimeout
openclawModelShowTimeout = 50 * time.Millisecond
t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
time.Sleep(300 * time.Millisecond)
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
start := time.Now()
cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
elapsed := time.Since(start)
if elapsed >= 250*time.Millisecond {
t.Fatalf("openclawModelConfig took too long: %v", elapsed)
}
if cfg["id"] != "slow-model" {
t.Errorf("id = %v, want slow-model", cfg["id"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set on timeout")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set on timeout")
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "test-model")
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set for zero value")
}
})
t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
// Use a model name that's in cloudModelLimits and make the server
// report it as a remote/cloud model
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud model")
}
if cfg["contextWindow"] != 204_800 {
t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
}
if cfg["maxTokens"] != 128_000 {
t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
}
})
t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports vision capability should have input: [text, image].
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud vision model")
}
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
}
})
t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports thinking capability should have reasoning: true.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud thinking model")
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for cloud thinking model")
}
})
}
func TestIntegrationOnboarded(t *testing.T) {
t.Run("returns false when not set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Error("expected false for fresh config")
}
})
t.Run("returns true after integrationOnboarded", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
})
t.Run("is case insensitive", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("OpenClaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true when set with different case")
}
})
t.Run("preserves existing integration data", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
// Verify onboarded is set
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
// Verify models are preserved
model := IntegrationModel("openclaw")
if model != "llama3.2" {
t.Errorf("expected first model llama3.2, got %q", model)
}
})
}

View File

@@ -10,10 +10,11 @@ import (
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.

View File

@@ -524,7 +524,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter", " ":
item := m.items[m.cursor]
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
return m, nil
}
@@ -555,6 +555,12 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
if config.AutoInstallable(item.integration) {
// Auto-installable: select to trigger install flow
m.selected = true
m.quitting = true
return m, tea.Quit
}
return m, nil
}
if item.integration != "" && config.IsEditorIntegration(item.integration) {
@@ -618,7 +624,11 @@ func (m model) View() string {
var modelSuffix string
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
if config.AutoInstallable(item.integration) {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
@@ -634,7 +644,9 @@ func (m model) View() string {
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
if config.AutoInstallable(item.integration) {
desc = "Press enter to install"
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"

View File

@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
if err != nil {
return nil, nil, err
}
bts = sanitizeNonFiniteJSON(bts)
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json: %w", err)
}
if len(p.Architectures) < 1 {
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
conv = &lfm2Model{}
case "Lfm2VlForConditionalGeneration":
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
}
if t, ok := conv.(moreParser); ok {

View File

@@ -1,6 +1,7 @@
package convert
import (
"fmt"
"slices"
"strings"
@@ -13,25 +14,110 @@ type lfm2Model struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
BlockFFDim uint32 `json:"block_ff_dim"`
BlockMultipleOf uint32 `json:"block_multiple_of"`
BlockAutoAdjustFFDim bool `json:"block_auto_adjust_ff_dim"`
BlockFFNDimMultiplier float32 `json:"block_ffn_dim_multiplier"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
NumExperts uint32 `json:"num_experts"`
NumLocalExperts uint32 `json:"num_local_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumDenseLayers uint32 `json:"num_dense_layers"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
}
var _ ModelConverter = (*lfm2Model)(nil)
const (
defaultMaxPositionEmbeddings = uint32(128_000)
fallbackContextLength = uint32(32_768)
)
func (p *lfm2Model) isMoE() bool {
return p.ModelType == "lfm2_moe" || p.expertCount() > 0
}
func (p *lfm2Model) ropeFreqBase() float32 {
if p.RopeTheta != 0 {
return p.RopeTheta
}
return p.RopeParameters.RopeTheta
}
func (p *lfm2Model) expertCount() uint32 {
if p.NumLocalExperts > 0 {
return p.NumLocalExperts
}
return p.NumExperts
}
func (p *lfm2Model) feedForwardLength() uint32 {
ff := p.IntermediateSize
if p.BlockFFDim != 0 {
ff = p.BlockFFDim
}
if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
return ff
}
ff = (2 * ff) / 3
// Keep default multiplier behavior consistent with llama.cpp conversion.
if p.BlockFFNDimMultiplier != 0 {
ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
}
m := p.BlockMultipleOf
return m * ((ff + m - 1) / m)
}
func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
return p.isMoE() &&
p.VocabSize == 65536 &&
p.HiddenSize == 2048 &&
p.NumHiddenLayers == 40 &&
p.IntermediateSize == 11776 &&
p.NumAttentionHeads == 32 &&
p.NumKeyValueHeads == 8 &&
p.NumDenseLayers == 2 &&
p.expertCount() == 64 &&
p.NumExpertsPerToken == 4 &&
p.MoEIntermediateSize == 1536
}
func (p *lfm2Model) contextLength() uint32 {
if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
return fallbackContextLength
}
return p.MaxPositionEmbeddings
}
func (p *lfm2Model) KV(t *Tokenizer) KV {
architecture := "lfm2"
if p.isMoE() {
architecture = "lfm2moe"
}
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
kv["general.architecture"] = architecture
kv["tokenizer.ggml.pre"] = "lfm2"
kv["vocab_size"] = p.VocabSize
kv["block_count"] = p.NumHiddenLayers
kv["embedding_length"] = p.HiddenSize
kv["feed_forward_length"] = p.feedForwardLength()
kv["context_length"] = p.contextLength()
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
@@ -42,13 +128,24 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
kv["attention.head_count"] = p.NumAttentionHeads
kv["attention.head_count_kv"] = kvHeadCounts
kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.layer_norm_rms_epsilon"] = p.NormEps
kv["shortconv.l_cache"] = p.ConvLCache
if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
kv["rope.freq_base"] = ropeFreqBase
}
if p.isMoE() {
kv["expert_count"] = p.expertCount()
kv["expert_used_count"] = p.NumExpertsPerToken
kv["expert_feed_forward_length"] = p.MoEIntermediateSize
kv["leading_dense_block_count"] = p.NumDenseLayers
kv["expert_gating_func"] = uint32(2) // sigmoid
}
return kv
}
@@ -56,6 +153,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
if p.isMoE() {
merges := make([]merge, 0, p.NumHiddenLayers*3)
for i := range p.NumHiddenLayers {
if i < p.NumDenseLayers {
continue
}
merges = append(merges, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
})
}
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
ts = remaining
}
for _, t := range ts {
shape := t.Shape()
@@ -80,7 +201,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.embedding_norm", "token_embd_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
@@ -92,6 +213,8 @@ func (p *lfm2Model) Replacements() []string {
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.gate", "ffn_gate_inp",
"feed_forward.expert_bias", "exp_probs_b.bias",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",

View File

@@ -0,0 +1,271 @@
package convert
import (
"io"
"slices"
"strings"
"testing"
)
type lfm2StubTensor struct {
tensorBase
}
func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
return &lfm2StubTensor{
tensorBase: tensorBase{
name: name,
shape: shape,
},
}
}
func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
return 0, nil
}
func (t *lfm2StubTensor) Clone() Tensor {
return &lfm2StubTensor{
tensorBase: tensorBase{
name: t.name,
shape: slices.Clone(t.shape),
},
}
}
func TestLFM2MoEKV(t *testing.T) {
var p lfm2Model
p.ModelParameters.ModelType = "lfm2_moe"
p.VocabSize = 65536
p.HiddenSize = 2048
p.NumHiddenLayers = 4
p.MaxPositionEmbeddings = 128000
p.IntermediateSize = 11776
p.NumAttentionHeads = 32
p.NumKeyValueHeads = 8
p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
p.NormEps = 1e-5
p.ConvLCache = 3
p.MoEIntermediateSize = 1536
p.NumExperts = 64
p.NumExpertsPerToken = 4
p.NumDenseLayers = 2
p.RopeParameters.RopeTheta = 1_000_000
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["general.architecture"], "lfm2moe"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
}
if got, want := kv["expert_count"], uint32(64); got != want {
t.Fatalf("expert_count = %v, want %v", got, want)
}
if got, want := kv["expert_used_count"], uint32(4); got != want {
t.Fatalf("expert_used_count = %v, want %v", got, want)
}
if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
}
if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
}
if got, want := kv["expert_gating_func"], uint32(2); got != want {
t.Fatalf("expert_gating_func = %v, want %v", got, want)
}
gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
}
wantHeadCounts := []uint32{0, 8, 0, 8}
if !slices.Equal(gotHeadCounts, wantHeadCounts) {
t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
}
if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
t.Fatalf("rope.freq_base = %v, want %v", got, want)
}
}
func TestLFM2DenseKV(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 32000},
HiddenSize: 1024,
NumHiddenLayers: 2,
MaxPositionEmbeddings: 32768,
IntermediateSize: 4096,
NumAttentionHeads: 16,
NumKeyValueHeads: 4,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
RopeTheta: 10000,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["general.architecture"], "lfm2"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
}
if _, ok := kv["expert_count"]; ok {
t.Fatalf("expert_count should not be set for dense lfm2")
}
}
func TestLFM2MoETensors(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
NumHiddenLayers: 4,
NumDenseLayers: 2,
}
in := []Tensor{
newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
}
out := p.Tensors(in)
byName := make(map[string][]uint64, len(out))
for _, tns := range out {
byName[tns.Name] = tns.Shape
}
if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
}
if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
}
if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
}
if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
t.Fatalf("missing shortconv tensor")
} else if !slices.Equal(got, []uint64{2048, 3}) {
t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
}
if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
t.Fatalf("unmerged expert tensor should not be present")
}
}
func TestLFM2MoEReplacements(t *testing.T) {
p := lfm2Model{}
replacer := strings.NewReplacer(p.Replacements()...)
if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want {
t.Fatalf("expert bias replacement = %q, want %q", got, want)
}
if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want {
t.Fatalf("gate replacement = %q, want %q", got, want)
}
}
func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 40,
MaxPositionEmbeddings: 128000,
IntermediateSize: 11776,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: make([]string, 40),
NormEps: 1e-5,
ConvLCache: 3,
MoEIntermediateSize: 1536,
NumExperts: 64,
NumExpertsPerToken: 4,
NumDenseLayers: 2,
}
for i := 0; i < len(p.LayerTypes); i++ {
p.LayerTypes[i] = "conv"
}
p.LayerTypes[2] = "full_attention"
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["context_length"], uint32(32768); got != want {
t.Fatalf("context_length = %v, want %v", got, want)
}
}
func TestLFM2KVContextLengthNoOverride(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 39, // mismatch: should not trigger edge case
MaxPositionEmbeddings: 128000,
IntermediateSize: 11776,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
MoEIntermediateSize: 1536,
NumExperts: 64,
NumExpertsPerToken: 4,
NumDenseLayers: 2,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["context_length"], uint32(128000); got != want {
t.Fatalf("context_length = %v, want %v", got, want)
}
}
func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 16,
MaxPositionEmbeddings: 128000,
IntermediateSize: 12288, // should be ignored when block_ff_dim is set
BlockFFDim: 12288,
BlockAutoAdjustFFDim: true,
BlockMultipleOf: 256,
BlockFFNDimMultiplier: 1.0,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
t.Fatalf("feed_forward_length = %v, want %v", got, want)
}
}

417
convert/convert_lfm2_vl.go Normal file
View File

@@ -0,0 +1,417 @@
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io/fs"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
type lfm2VLTextModel struct {
TextConfig lfm2Model `json:"text_config"`
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
EncoderPatchSize uint32 `json:"encoder_patch_size"`
ImageTokenID uint32 `json:"image_token_id"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
TileSize uint32 `json:"tile_size"`
MaxPixelsTolerance float32 `json:"max_pixels_tolerance"`
ProjectorUseLayernorm bool `json:"projector_use_layernorm"`
ProjectorHiddenSize uint32 `json:"projector_hidden_size"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
UseImageSpecialTokens *bool `json:"use_image_special_tokens"`
UseThumbnail *bool `json:"use_thumbnail"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
} `json:"vision_config"`
Processor struct {
ImageProcessor struct {
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
MaxPixelsTol float32 `json:"max_pixels_tolerance"`
TileSize uint32 `json:"tile_size"`
UseThumbnail *bool `json:"use_thumbnail"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
Size struct {
Height uint32 `json:"height"`
Width uint32 `json:"width"`
} `json:"size"`
} `json:"image_processor"`
}
}
func (p *lfm2VLTextModel) textModel() *lfm2Model {
return &p.TextConfig
}
func (p *lfm2VLTextModel) specialTokenTypes() []string {
return p.textModel().specialTokenTypes()
}
func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "processor_config.json")
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
return json.Unmarshal(bts, &p.Processor)
}
func (p *lfm2VLTextModel) visionImageSize() uint32 {
// LFM2-VL image processor operates on 512 tiles and downsamples by factor 2
// before projection. Keep a fixed square image size compatible with position
// embeddings and the simplified runtime image pipeline.
tile := cmp.Or(
p.Processor.ImageProcessor.TileSize,
p.Processor.ImageProcessor.Size.Height,
p.Processor.ImageProcessor.Size.Width,
uint32(512),
)
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
if downsample == 0 {
return tile
}
return max(uint32(1), tile/downsample)
}
func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
kv := p.textModel().KV(t)
boolOr := func(defaultValue bool, values ...*bool) bool {
for _, v := range values {
if v != nil {
return *v
}
}
return defaultValue
}
kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
kv["vision.image_size"] = p.visionImageSize()
kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))
setVisionTokenID := func(k, token string) {
if t == nil || t.Vocabulary == nil {
return
}
for i, v := range t.Vocabulary.Tokens {
if v == token {
kv[k] = uint32(i)
return
}
}
}
setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")
return kv
}
func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))
for _, t := range ts {
if t.Name() == "v.patch_embd.weight" {
shape := t.Shape()
if len(shape) == 2 {
inputDim := uint64(numChannels * patchSize * patchSize)
if shape[1] == inputDim {
channels := numChannels
patch := patchSize
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
})
}
}
}
}
out := p.textModel().Tensors(ts)
for _, t := range out {
if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
}
}
return out
}
func (p *lfm2VLTextModel) Replacements() []string {
out := make([]string, 0, 96)
addText := func(from, to string) {
out = append(out, from, to)
if strings.HasPrefix(from, "model.") {
suffix := strings.TrimPrefix(from, "model.")
out = append(out,
"model.language_model."+suffix, to,
"model.language_model.model."+suffix, to,
)
}
}
base := p.textModel().Replacements()
for i := 0; i+1 < len(base); i += 2 {
addText(base[i], base[i+1])
}
// Vision tower + multimodal projector tensors (single-file conversion).
out = append(out,
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
"model.vision_tower.vision_model.encoder.layers", "v.blk",
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
"model.multi_modal_projector.layer_norm", "mm.layer_norm",
"model.multi_modal_projector.linear_1", "mm.1",
"model.multi_modal_projector.linear_2", "mm.2",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_out",
"layer_norm1", "ln1",
"layer_norm2", "ln2",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
)
return out
}
// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
type lfm2VLProjectorModel struct {
ModelParameters
DownsampleFactor uint32 `json:"downsample_factor"`
ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
VisionModel struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
ImageSize uint32 `json:"image_size"`
} `json:"vision_config"`
Processor struct {
ImageProcessor struct {
DownsampleFactor uint32 `json:"downsample_factor"`
TileSize uint32 `json:"tile_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
Size struct {
Height uint32 `json:"height"`
Width uint32 `json:"width"`
} `json:"size"`
} `json:"image_processor"`
}
}
var (
_ ModelConverter = (*lfm2VLTextModel)(nil)
_ ModelConverter = (*lfm2VLProjectorModel)(nil)
_ moreParser = (*lfm2VLTextModel)(nil)
_ moreParser = (*lfm2VLProjectorModel)(nil)
)
func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "processor_config.json")
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
return json.Unmarshal(bts, &p.Processor)
}
func (p *lfm2VLProjectorModel) imageSize() uint32 {
if p.VisionModel.ImageSize > 0 {
return p.VisionModel.ImageSize
}
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
baseSize := cmp.Or(
p.Processor.ImageProcessor.TileSize,
p.Processor.ImageProcessor.Size.Height,
p.Processor.ImageProcessor.Size.Width,
uint32(256),
)
if downsample == 0 {
return baseSize
}
return max(uint32(1), baseSize/downsample)
}
func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
kv := KV{
"general.architecture": "clip",
"general.type": "mmproj",
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"clip.has_vision_encoder": true,
"clip.projector_type": "lfm2",
"clip.use_gelu": true,
}
kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
kv["clip.vision.image_size"] = p.imageSize()
kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
return kv
}
func defaultFloat32Slice(v, fallback []float32) []float32 {
if len(v) > 0 {
return v
}
return fallback
}
func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))
for _, t := range ts {
name := t.Name()
if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
continue
}
shape := t.Shape()
if name == "v.patch_embd.weight" && len(shape) == 2 {
inputDim := uint64(numChannels * patchSize * patchSize)
if shape[1] == inputDim {
shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
channels := int(numChannels)
patch := int(patchSize)
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
})
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2VLProjectorModel) Replacements() []string {
return []string{
"model.multi_modal_projector.linear_1", "mm.1",
"model.multi_modal_projector.linear_2", "mm.2",
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
"model.vision_tower.vision_model.encoder.layers", "v.blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_out",
"layer_norm1", "ln1",
"layer_norm2", "ln2",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
}
}
func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
if len(srcShape) != 2 {
return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
}
outDim := int(srcShape[0])
flatInputDim := int(srcShape[1])
expectedInputDim := channels * patch * patch
if flatInputDim != expectedInputDim {
return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim)
}
expectedSize := outDim * flatInputDim
if len(data) != expectedSize {
return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize)
}
repacked := make([]float32, len(data))
perChannel := patch * patch
for o := range outDim {
inBase := o * flatInputDim
outBase := o * flatInputDim
for y := range patch {
for x := range patch {
inPixelBase := inBase + (y*patch+x)*channels
for c := range channels {
src := inPixelBase + c
dst := outBase + c*perChannel + y*patch + x
repacked[dst] = data[src]
}
}
}
}
return repacked, nil
}

View File

@@ -0,0 +1,249 @@
package convert
import (
"slices"
"strings"
"testing"
)
func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
p := lfm2VLTextModel{
TextConfig: lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 16,
MaxPositionEmbeddings: 128000,
IntermediateSize: 12288,
BlockFFDim: 12288,
BlockAutoAdjustFFDim: true,
BlockMultipleOf: 256,
BlockFFNDimMultiplier: 1.0,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
},
DownsampleFactor: 2,
VisionConfig: struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
}{
HiddenSize: 1152,
IntermediateSize: 4304,
NumAttentionHeads: 16,
NumHiddenLayers: 27,
NumChannels: 3,
PatchSize: 16,
LayerNormEpsilon: 1e-6,
},
}
p.Processor.ImageProcessor.TileSize = 512
p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
kv := p.KV(&Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
},
})
if got, want := kv["general.architecture"], "lfm2"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
t.Fatalf("feed_forward_length = %v, want %v", got, want)
}
if got, want := kv["vision.block_count"], uint32(27); got != want {
t.Fatalf("vision.block_count = %v, want %v", got, want)
}
if got, want := kv["vision.image_size"], uint32(256); got != want {
t.Fatalf("vision.image_size = %v, want %v", got, want)
}
if got, want := kv["vision.image_token_id"], uint32(396); got != want {
t.Fatalf("vision.image_token_id = %v, want %v", got, want)
}
if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
}
if got, want := kv["vision.do_image_splitting"], true; got != want {
t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
}
if got, want := kv["vision.min_tiles"], uint32(2); got != want {
t.Fatalf("vision.min_tiles = %v, want %v", got, want)
}
if got, want := kv["vision.max_tiles"], uint32(10); got != want {
t.Fatalf("vision.max_tiles = %v, want %v", got, want)
}
if got, want := kv["vision.tile_size"], uint32(512); got != want {
t.Fatalf("vision.tile_size = %v, want %v", got, want)
}
if got, want := kv["vision.use_thumbnail"], true; got != want {
t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
}
if got, want := kv["vision.use_image_special_tokens"], true; got != want {
t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
}
}
func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
p := lfm2VLTextModel{}
p.VisionConfig.PatchSize = 16
p.VisionConfig.NumChannels = 3
input := []Tensor{
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
}
out := p.Tensors(input)
if len(out) == 0 {
t.Fatal("expected non-empty tensor list")
}
foundPatch := false
foundVision := false
for _, tns := range out {
if tns.Name == "v.patch_embd.weight" {
foundPatch = true
if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
}
}
if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
foundVision = true
}
}
if !foundPatch {
t.Fatal("expected v.patch_embd.weight in output tensors")
}
if !foundVision {
t.Fatal("expected at least one vision/projector tensor in output")
}
}
func TestLFM2VLTextModelReplacements(t *testing.T) {
p := lfm2VLTextModel{}
r := strings.NewReplacer(p.Replacements()...)
tests := []struct {
name string
in string
want string
}{
{
name: "language_model_embed_tokens",
in: "model.language_model.embed_tokens.weight",
want: "token_embd.weight",
},
{
name: "language_model_layers",
in: "model.language_model.layers.2.self_attn.q_proj.weight",
want: "blk.2.attn_q.weight",
},
{
name: "nested_language_model_prefix",
in: "model.language_model.model.embedding_norm.weight",
want: "token_embd_norm.weight",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := r.Replace(tt.in); got != tt.want {
t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestLFM2VLProjectorKV(t *testing.T) {
p := lfm2VLProjectorModel{
DownsampleFactor: 2,
ProjectorHiddenDim: 2048,
}
p.VisionModel.NumHiddenLayers = 27
p.VisionModel.HiddenSize = 1152
p.VisionModel.IntermediateSize = 4304
p.VisionModel.NumAttentionHeads = 16
p.VisionModel.PatchSize = 16
p.VisionModel.LayerNormEpsilon = 1e-6
p.Processor.ImageProcessor.TileSize = 512
p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
kv := p.KV(nil)
if got, want := kv["general.architecture"], "clip"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["clip.projector_type"], "lfm2"; got != want {
t.Fatalf("clip.projector_type = %v, want %v", got, want)
}
if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
}
}
func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
p := lfm2VLProjectorModel{}
p.VisionModel.NumChannels = 3
p.VisionModel.PatchSize = 16
input := []Tensor{
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
}
out := p.Tensors(input)
if len(out) != 2 {
t.Fatalf("expected 2 tensors, got %d", len(out))
}
var patchShape []uint64
for _, tns := range out {
if tns.Name == "v.patch_embd.weight" {
patchShape = tns.Shape
break
}
}
if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
}
}
func TestRepackPatchEmbeddingWeight(t *testing.T) {
data := []float32{
0, 1, // y=0,x=0
2, 3, // y=0,x=1
4, 5, // y=1,x=0
6, 7, // y=1,x=1
}
got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
if !slices.Equal(got, want) {
t.Fatalf("repacked data = %v, want %v", got, want)
}
}

View File

@@ -0,0 +1,385 @@
package convert
import (
"cmp"
"encoding/json"
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type hybridPattern string
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*p = ""
return nil
}
var single string
if err := json.Unmarshal(data, &single); err == nil {
*p = hybridPattern(strings.TrimSpace(single))
return nil
}
var parts []string
if err := json.Unmarshal(data, &parts); err == nil {
*p = hybridPattern(strings.Join(parts, ""))
return nil
}
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}
type nemotronHModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
ConvKernel uint32 `json:"conv_kernel"`
SSMStateSize uint32 `json:"ssm_state_size"`
MambaNumHeads uint32 `json:"mamba_num_heads"`
MambaHeadDim uint32 `json:"mamba_head_dim"`
NGroups uint32 `json:"n_groups"`
IntermediateSize uint32 `json:"intermediate_size"`
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
// MoE
NumExperts uint32 `json:"num_experts"`
NumSharedExperts uint32 `json:"num_shared_experts"`
NRoutedExperts uint32 `json:"n_routed_experts"`
NSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
MoESharedExpertIntermediate uint32 `json:"moe_shared_expert_intermediate_size"`
NormTopKProb bool `json:"norm_topk_prob"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
ExpertGroupCount uint32 `json:"n_group"`
ExpertGroupUsedCount uint32 `json:"topk_group"`
}
var _ ModelConverter = (*nemotronHModel)(nil)
func (n *nemotronHModel) parseMore(_ fs.FS) error {
if n.NumHiddenLayers == 0 {
return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
}
if n.HiddenSize == 0 {
return fmt.Errorf("nemotron_h: hidden_size must be set")
}
if n.NumAttentionHeads == 0 {
return fmt.Errorf("nemotron_h: num_attention_heads must be set")
}
if n.HeadDim == 0 {
if n.HiddenSize%n.NumAttentionHeads != 0 {
return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
}
n.HeadDim = n.HiddenSize / n.NumAttentionHeads
}
if n.NumKeyValueHeads == 0 {
n.NumKeyValueHeads = n.NumAttentionHeads
}
if n.ConvKernel == 0 {
return fmt.Errorf("nemotron_h: conv_kernel must be set")
}
if n.SSMStateSize == 0 {
return fmt.Errorf("nemotron_h: ssm_state_size must be set")
}
if n.ssmHeadCount() == 0 {
return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
}
if n.MambaHeadDim == 0 {
return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
}
if n.NGroups == 0 {
n.NGroups = 1
}
if _, _, err := n.layerArrays(); err != nil {
return err
}
if n.isMoE() {
if n.routedExpertCount() == 0 {
return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
}
if n.NumExpertsPerTok == 0 {
return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
}
if n.NumExpertsPerTok > n.routedExpertCount() {
return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
}
if n.moeIntermediateSize() == 0 {
return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
}
}
return nil
}
func (n *nemotronHModel) isMoE() bool {
return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
}
func (n *nemotronHModel) routedExpertCount() uint32 {
return cmp.Or(n.NRoutedExperts, n.NumExperts)
}
func (n *nemotronHModel) sharedExpertCount() uint32 {
return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
}
func (n *nemotronHModel) ssmHeadCount() uint32 {
return n.MambaNumHeads
}
func (n *nemotronHModel) ssmInnerSize() uint32 {
return n.MambaHeadDim * n.ssmHeadCount()
}
func (n *nemotronHModel) epsilon() float32 {
return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
}
func (n *nemotronHModel) moeIntermediateSize() uint32 {
return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
}
func (n *nemotronHModel) denseIntermediateSize() uint32 {
return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
}
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
if pattern == "" {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
}
runes := []rune(pattern)
if len(runes) != int(n.NumHiddenLayers) {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
}
headCountKV = make([]uint32, n.NumHiddenLayers)
ffnLengths = make([]uint32, n.NumHiddenLayers)
attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
moeFFN := n.moeIntermediateSize()
denseFFN := n.denseIntermediateSize()
for i, layerType := range runes {
switch layerType {
case 'M':
// Recurrent layer: no KV heads and no FFN.
case '*', 'A':
// Attention-only layer.
headCountKV[i] = attnKVHeads
case 'E':
// MoE layer.
if moeFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
}
ffnLengths[i] = moeFFN
case '-':
// Dense FFN layer.
if denseFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
}
ffnLengths[i] = denseFFN
default:
return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
}
}
return headCountKV, ffnLengths, nil
}
func (n *nemotronHModel) KV(t *Tokenizer) KV {
kv := n.ModelParameters.KV(t)
arch := "nemotron_h"
if n.isMoE() {
arch = "nemotron_h_moe"
}
kv["general.architecture"] = arch
kv["block_count"] = n.NumHiddenLayers
kv["context_length"] = n.MaxPositionEmbeddings
kv["embedding_length"] = n.HiddenSize
kv["attention.head_count"] = n.NumAttentionHeads
kv["attention.key_length"] = n.HeadDim
kv["attention.value_length"] = n.HeadDim
kv["attention.layer_norm_epsilon"] = n.epsilon()
kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
}
if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
kv["attention.head_count_kv"] = headCountKV
kv["feed_forward_length"] = ffnLengths
}
kv["ssm.conv_kernel"] = n.ConvKernel
kv["ssm.inner_size"] = n.ssmInnerSize()
kv["ssm.state_size"] = n.SSMStateSize
kv["ssm.group_count"] = n.NGroups
kv["ssm.time_step_rank"] = n.ssmHeadCount()
if n.isMoE() {
kv["expert_count"] = n.routedExpertCount()
kv["expert_used_count"] = n.NumExpertsPerTok
kv["expert_feed_forward_length"] = n.moeIntermediateSize()
if n.sharedExpertCount() > 0 {
kv["expert_shared_count"] = n.sharedExpertCount()
}
if n.MoESharedExpertIntermediate > 0 {
kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
}
kv["expert_weights_norm"] = n.NormTopKProb
kv["expert_weights_scale"] = n.RoutedScalingFactor
if n.ExpertGroupCount > 0 {
kv["expert_group_count"] = n.ExpertGroupCount
}
if n.ExpertGroupUsedCount > 0 {
kv["expert_group_used_count"] = n.ExpertGroupUsedCount
}
}
return kv
}
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
switch len(shape) {
case 1:
return []uint64{shape[0], 1}
case 2:
if shape[0] == 1 && shape[1] > 1 {
return []uint64{shape[1], 1}
}
if shape[1] == 1 && shape[0] > 1 {
return []uint64{shape[0], 1}
}
}
return slices.Clone(shape)
}
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
remaining := ts
if n.isMoE() {
merges := make([]merge, 0, n.NumHiddenLayers*2)
for i := range n.NumHiddenLayers {
merges = append(merges, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
})
}
merged, rest := mergeTensors(ts, merges...)
out = append(out, merged...)
remaining = rest
}
nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
for _, t := range remaining {
name := t.Name()
shape := slices.Clone(t.Shape())
switch {
case strings.HasSuffix(name, ".ssm_a"):
shape = normalizeVectorShapeToColumn(shape)
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
out := make([]float32, len(data))
for i, v := range data {
out[i] = -float32(math.Exp(float64(v)))
}
return out, nil
})
case strings.HasSuffix(name, ".ssm_d"):
shape = normalizeVectorShapeToColumn(shape)
case strings.HasSuffix(name, ".ssm_norm.weight"):
switch len(shape) {
case 1:
if nGroups > 0 && shape[0]%nGroups == 0 {
shape = []uint64{nGroups, shape[0] / nGroups}
}
case 2:
if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
shape = []uint64{nGroups, shape[1] / nGroups}
}
}
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
if len(shape) == 3 {
if shape[0] == 1 {
shape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
return out
}
func (n *nemotronHModel) Replacements() []string {
return []string{
// Embedding and output
"lm_head", "output",
"backbone.embeddings", "token_embd",
"backbone.norm_f", "output_norm",
"backbone.layers", "blk",
// Recurrent (Mamba2) tensors
"mixer.in_proj", "ssm_in",
"mixer.out_proj", "ssm_out",
"mixer.dt_bias", "ssm_dt.bias",
"mixer.A_log", "ssm_a",
"mixer.D", "ssm_d",
"mixer.conv1d", "ssm_conv1d",
"mixer.norm.weight", "ssm_norm.weight",
// Attention tensors
"mixer.q_proj", "attn_q",
"mixer.k_proj", "attn_k",
"mixer.v_proj", "attn_v",
"mixer.o_proj", "attn_output",
// FFN / MoE tensors
"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
"mixer.gate", "ffn_gate_inp",
"mixer.fc1_latent_proj", "ffn_latent_in",
"mixer.fc2_latent_proj", "ffn_latent_out",
"mixer.shared_experts.up_proj", "ffn_up_shexp",
"mixer.shared_experts.down_proj", "ffn_down_shexp",
"mixer.up_proj", "ffn_up",
"mixer.down_proj", "ffn_down",
// Per-layer pre-norm
".norm.weight", ".attn_norm.weight",
}
}

View File

@@ -0,0 +1,230 @@
package convert
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
func TestHybridPatternUnmarshal(t *testing.T) {
t.Run("string", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
t.Run("array", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
}
func TestNemotronHLayerArrays(t *testing.T) {
m := &nemotronHModel{
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
}
headsKV, ffn, err := m.layerArrays()
if err != nil {
t.Fatal(err)
}
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
}
if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHKV(t *testing.T) {
m := &nemotronHModel{
MaxPositionEmbeddings: 1048576,
HiddenSize: 2688,
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 2,
HeadDim: 128,
LayerNormEpsilon: 1e-5,
RopeTheta: 10000,
PartialRotaryFactor: 0.5,
ConvKernel: 4,
SSMStateSize: 128,
MambaNumHeads: 64,
MambaHeadDim: 64,
NGroups: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NSharedExperts: 1,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
MoESharedExpertIntermediate: 3712,
NormTopKProb: true,
RoutedScalingFactor: 2.5,
}
if err := m.parseMore(nil); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
ffnLength, ok := kv["feed_forward_length"].([]uint32)
if !ok {
t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
}
if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHTensorsTransforms(t *testing.T) {
m := &nemotronHModel{NGroups: 8}
in := []Tensor{
&fakeTensor{
name: "blk.0.ssm_a",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_d",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_norm.weight",
shape: []uint64{16},
data: make([]float32, 16),
},
&fakeTensor{
name: "blk.0.ssm_conv1d.weight",
shape: []uint64{10, 1, 4},
data: make([]float32, 40),
},
}
out := m.Tensors(in)
if len(out) != len(in) {
t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
}
got := map[string]struct {
shape []uint64
writer io.WriterTo
}{}
for _, t := range out {
got[t.Name] = struct {
shape []uint64
writer io.WriterTo
}{shape: t.Shape, writer: t.WriterTo}
}
if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_a shape: %v", shape)
}
if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_d shape: %v", shape)
}
if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
t.Fatalf("unexpected ssm_norm shape: %v", shape)
}
if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
}
var b bytes.Buffer
if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
t.Fatal(err)
}
values := make([]float32, 4)
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
t.Fatal(err)
}
// 0 -> -exp(0) == -1
if values[0] != -1 {
t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
}
}
func TestNemotronHLoadModelMetadata(t *testing.T) {
tempDir := t.TempDir()
config := `{
"architectures": ["NemotronHForCausalLM"],
"model_type": "nemotron_h",
"num_hidden_layers": 4,
"hidden_size": 512,
"max_position_embeddings": 32768,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 64,
"layer_norm_epsilon": 1e-5,
"conv_kernel": 4,
"ssm_state_size": 128,
"mamba_num_heads": 16,
"mamba_head_dim": 32,
"n_groups": 8,
"hybrid_override_pattern": "ME*M",
"n_routed_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 256
}`
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
if err != nil {
t.Fatal(err)
}
if _, ok := kv.(*nemotronHModel); !ok {
t.Fatalf("unexpected converter type: %T", kv)
}
}
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
m := &nemotronHModel{}
r := strings.NewReplacer(m.Replacements()...)
if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
}
if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
}
}

97
convert/json_compat.go Normal file
View File

@@ -0,0 +1,97 @@
package convert
// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
//
// This is intentionally conservative:
// - only runs outside quoted strings
// - only rewrites full tokens
//
// We map these values to 0 because encoding/json rejects non-finite values,
// and these fields are typically model-side metadata not consumed by the
// converter.
func sanitizeNonFiniteJSON(in []byte) []byte {
if len(in) == 0 {
return in
}
out := make([]byte, 0, len(in))
inString := false
escape := false
for i := 0; i < len(in); {
c := in[i]
if inString {
out = append(out, c)
if escape {
escape = false
} else if c == '\\' {
escape = true
} else if c == '"' {
inString = false
}
i++
continue
}
if c == '"' {
inString = true
out = append(out, c)
i++
continue
}
if hasToken(in, i, "-Infinity") {
out = append(out, '0')
i += len("-Infinity")
continue
}
if hasToken(in, i, "Infinity") {
out = append(out, '0')
i += len("Infinity")
continue
}
if hasToken(in, i, "NaN") {
out = append(out, '0')
i += len("NaN")
continue
}
out = append(out, c)
i++
}
return out
}
func hasToken(in []byte, at int, tok string) bool {
end := at + len(tok)
if at < 0 || end > len(in) {
return false
}
if string(in[at:end]) != tok {
return false
}
if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
return false
}
if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
return false
}
return true
}
func isJSONWhitespace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
func isJSONValuePrefixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
}
func isJSONValueSuffixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
}

View File

@@ -0,0 +1,46 @@
package convert
import "testing"
func TestSanitizeNonFiniteJSON(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{
name: "infinity token",
in: `{"a":[0,Infinity,1]}`,
want: `{"a":[0,0,1]}`,
},
{
name: "negative infinity token",
in: `{"a":-Infinity}`,
want: `{"a":0}`,
},
{
name: "nan token",
in: `{"a":NaN}`,
want: `{"a":0}`,
},
{
name: "tokens inside strings untouched",
in: `{"a":"Infinity -Infinity NaN","b":Infinity}`,
want: `{"a":"Infinity -Infinity NaN","b":0}`,
},
{
name: "identifier-like token untouched",
in: `{"a":InfinityValue}`,
want: `{"a":InfinityValue}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
if got != tt.want {
t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -212,8 +212,13 @@ type tokenizer struct {
PreTokenizer struct {
PreTokenizers []struct {
Type string `json:"type"`
Pattern struct {
Type string `json:"type"`
Behavior string `json:"behavior"`
Invert bool `json:"invert"`
AddPrefixSpace bool `json:"add_prefix_space"`
TrimOffsets bool `json:"trim_offsets"`
UseRegex bool `json:"use_regex"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`

View File

@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "llama-bpe pretokenizer and control tokens",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{"id": 1, "content": "<|startoftext|>", "special": true},
{"id": 6, "content": "<|im_start|>", "special": true},
{"id": 7, "content": "<|im_end|>", "special": true},
{"id": 8, "content": "<|tool_list_start|>", "special": true},
{"id": 9, "content": "<|tool_list_end|>", "special": true},
{"id": 10, "content": "<|tool_call_start|>", "special": true},
{"id": 11, "content": "<|tool_call_end|>", "special": true},
{"id": 12, "content": "<|tool_response_start|>", "special": true},
{"id": 13, "content": "<|tool_response_end|>", "special": true},
{"id": 396, "content": "<image>", "special": true},
{"id": 64400, "content": "<think>", "special": true},
{"id": 64401, "content": "</think>", "special": true}
],
"model": {
"vocab": {
"<|startoftext|>": 1,
"<|im_start|>": 6,
"<|im_end|>": 7,
"<|tool_list_start|>": 8,
"<|tool_list_end|>": 9,
"<|tool_call_start|>": 10,
"<|tool_call_end|>": 11,
"<|tool_response_start|>": 12,
"<|tool_response_end|>": 13,
"<image>": 396,
"<think>": 64400,
"</think>": 64401
}
},
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{
"<|startoftext|>",
"<|im_start|>",
"<|im_end|>",
"<|tool_list_start|>",
"<|tool_list_end|>",
"<|tool_call_start|>",
"<|tool_call_end|>",
"<|tool_response_start|>",
"<|tool_response_end|>",
"<image>",
"<think>",
"</think>",
},
Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
Types: []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
},
Pre: "llama-bpe",
},
},
{
name: "list string merges",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{

View File

@@ -4,47 +4,65 @@ title: OpenClaw
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
## Quick start
```bash
ollama launch openclaw
```
Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
## Configure without launching
To change the model without starting the gateway and TUI:
To configure without launching:
```shell
```bash
ollama launch openclaw --config
```
## Recommended Models
To use a specific model directly:
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
```bash
ollama launch openclaw --model kimi-k2.5:cloud
```
If the gateway is already running, it restarts automatically to pick up the new model.
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
- `glm-5:cloud` — Reasoning and code generation
**Local models:**
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
```bash
openclaw configure --section channels
```
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
## Stopping the gateway
```bash
openclaw gateway stop
```
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
func (kv KV) FFNLength() []uint64 {
ffnLengthDefault := uint32(0)
ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
if len(ffnLength) == 1 {
ffnLengthDefault = ffnLength[0]
}
nLayers := int(kv.BlockCount())
if len(ffnLength) > nLayers {
slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(ffnLength) {
out[i] = uint64(ffnLengthDefault)
} else {
out[i] = uint64(ffnLength[i])
}
}
return out
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
"llama4",
"mistral3",
"mllama",
"nemotron_h", "nemotron_h_moe",
"nomic-bert",
"olmo3",
"qwen25vl",
@@ -273,6 +295,7 @@ func (kv KV) OllamaEngineRequired() bool {
"glm4moelite",
"glmocr",
"lfm2",
"lfm2moe",
}, kv.Architecture())
}
@@ -864,7 +887,9 @@ func (f GGML) FlashAttention() bool {
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"lfm2moe",
"mistral3",
"nemotron_h", "nemotron_h_moe",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",

752
kvcache/recurrent.go Normal file
View File

@@ -0,0 +1,752 @@
package kvcache
import (
"errors"
"fmt"
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
DefaultCheckpointCount = 32
DefaultCheckpointMinPos = int32(16)
DefaultCheckpointInterval = int32(1280)
)
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
// Config configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
ConvDim int
ConvChannels int
RecurrentStateSize int
CheckpointLogPrefix string
}
var (
_ Cache = (*Recurrent)(nil)
_ CheckpointCache = (*Recurrent)(nil)
)
// Cache stores:
// - a standard causal KV cache
// - per-sequence conv state for recurrent operators
// - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
kv *Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int
convChannels int
// Recurrent state dimensions
recurrentStateSize int
logPrefix string
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
seqCounts map[int]int
slotScratch [1]int32
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer recurrent state buffers (allocated lazily)
recurrentCtxs map[int]ml.Context
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointRecurCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
return &Recurrent{
kv: NewCausalCache(config.Shift),
convDim: config.ConvDim,
convChannels: config.ConvChannels,
recurrentStateSize: config.RecurrentStateSize,
logPrefix: config.CheckpointLogPrefix,
slotForSeq: make(map[int]int),
seqCounts: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
recurrentCtxs: make(map[int]ml.Context),
recurrentStates: make(map[int]ml.Tensor),
checkpointCount: DefaultCheckpointCount,
checkpointMinPos: DefaultCheckpointMinPos,
checkpointInterval: DefaultCheckpointInterval,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointRecurCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *Recurrent) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.recurrentCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointRecurCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *Recurrent) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
nTokens := len(batch.Sequences)
if nTokens == 0 {
c.curSeqs = c.curSeqs[:0]
c.curSlots = c.curSlots[:0]
c.curSlotsInput = nil
c.curSeqTokens = 0
c.reserveCheckpoints = false
c.writableEnsured = false
c.writableError = nil
return nil
}
// Fast path for single-sequence batches (common during decode and prefill).
firstSeq := batch.Sequences[0]
singleSeq := true
for _, s := range batch.Sequences[1:] {
if s != firstSeq {
singleSeq = false
break
}
}
if singleSeq {
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
}
// Derive equal-length sequence layout for recurrent layers.
seqCounts := c.seqCounts
for s := range seqCounts {
delete(seqCounts, s)
}
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if seqCounts[s] == 0 {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return ErrNotSupported
}
}
c.curSeqTokens = want
if reserve {
c.curSlots = c.curSlots[:0]
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
}
c.finalizeStartForward(ctx, batch, true)
return nil
}
// Ensure slots exist for sequences in this batch.
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
c.finalizeStartForward(ctx, batch, false)
return nil
}
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
c.curSeqs = append(c.curSeqs[:0], seq)
c.curSeqTokens = seqTokens
if reserve {
c.curSlots = append(c.curSlots[:0], 0)
c.finalizeStartForward(ctx, batch, true)
return nil
}
slot, ok := c.slotForSeq[seq]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[seq] = slot
c.refCount[slot] = 1
slotList := [1]int{slot}
c.zeroSlots(ctx, slotList[:])
}
c.curSlots = append(c.curSlots[:0], slot)
c.finalizeStartForward(ctx, batch, false)
return nil
}
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
c.setCurSlotsInput(ctx)
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = reserve
c.planCheckpoints(batch)
}
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
}
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
switch len(slots) {
case 0:
return nil
case 1:
c.slotScratch[0] = int32(slots[0])
return ctx.Input().FromInts(c.slotScratch[:], 1)
default:
slotIndices := make([]int32, len(slots))
for i, v := range slots {
slotIndices[i] = int32(v)
}
return ctx.Input().FromInts(slotIndices, len(slotIndices))
}
}
func (c *Recurrent) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *Recurrent) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros recurrent state for the given slots across all cached layers.
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotsTensor := c.slotsInput(ctx, slots)
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
if len(c.recurrentStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
for _, buf := range c.recurrentStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
c.setCurSlotsInput(ctx)
return nil
}
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
if rows.DType() != ml.DTypeF32 {
rows = rows.Cast(ctx, ml.DTypeF32)
}
ctx.Forward(buf.SetRows(ctx, rows, dst))
}
for _, buf := range c.recurrentStates {
rows := buf.Rows(ctx, src)
if rows.DType() != ml.DTypeF32 {
rows = rows.Cast(ctx, ml.DTypeF32)
}
ctx.Forward(buf.SetRows(ctx, rows, dst))
}
}
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *Recurrent) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
delete(c.pendingRestore, seq)
slot, ok := c.slotForSeq[seq]
if !ok || !c.validSlot(slot) {
return nil
}
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
if c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
slot = newSlot
}
c.shiftCheckpoints(slot, beginIndex, endIndex)
return nil
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return ErrNotSupported
}
if !c.restoreComplete(restore) {
return ErrNotSupported
}
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *Recurrent) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *Recurrent) SlotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *Recurrent) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *Recurrent) SeqTokens() int {
return c.curSeqTokens
}
func (c *Recurrent) NumSeqs() int {
return len(c.curSeqs)
}
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
if buf, ok := c.recurrentStates[layer]; ok {
return buf
}
if _, ok := c.recurrentCtxs[layer]; !ok {
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
c.recurrentStates[layer] = buf
return buf
}
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
c.ensureWritableOnce(ctx)
return c.writableError
}
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
}
return buf.Rows(ctx, c.SlotsTensor())
}
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
if start, ok := c.contiguousSlots(); ok {
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
ctx.Forward(src.Copy(ctx, view))
return
}
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
}
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
buf := c.convBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
}
// UpdateConvState writes new conv state for current batch sequences.
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
if len(dims) == 0 {
return nil, ErrInvalidRecurrentShape
}
size := 1
for _, d := range dims {
if d <= 0 {
return nil, ErrInvalidRecurrentShape
}
size *= d
}
if size != c.recurrentStateSize {
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
}
buf := c.recurrentBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
shape := make([]int, 0, len(dims)+1)
shape = append(shape, dims...)
shape = append(shape, c.NumSeqs())
return cur.Reshape(ctx, shape...), nil
}
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
if err := c.ensureWritable(ctx); err != nil {
return nil, err
}
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
return nil, ErrInvalidRecurrentShape
}
size := dim0 * dim1 * dim2
if size != c.recurrentStateSize {
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
}
buf := c.recurrentBuffer(layer)
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
}
// UpdateRecurrentState writes new recurrent state for current batch sequences.
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.recurrentBuffer(layer)
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
srcF32 := src
if src.DType() != ml.DTypeF32 {
srcF32 = src.Cast(ctx, ml.DTypeF32)
}
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *Recurrent) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *Recurrent) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,561 @@
package kvcache
import (
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
recurrent map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
offset := beginIndex - endIndex
size := 0
next := -1
minPos := int32(math.MaxInt32)
maxPos := int32(-1)
minIdx := 0
for i := range s.entries {
pos := s.entries[i].pos
if pos >= 0 {
if pos >= beginIndex && pos < endIndex {
s.entries[i].pos = -1
} else if pos >= endIndex {
s.entries[i].pos = pos + offset
}
}
pos = s.entries[i].pos
if pos >= 0 {
size++
if pos < minPos {
minPos = pos
minIdx = i
}
if pos > maxPos {
maxPos = pos
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = maxPos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *Recurrent) checkpointTag() string {
if c.logPrefix == "" {
return "kvcache.recurrent"
}
return c.logPrefix
}
func (c *Recurrent) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.recurrent {
buf := c.recurrentBuffer(layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.recurrentStates {
if entry.recurrent == nil || entry.recurrent[layer] == nil {
return false
}
}
return true
}
func (c *Recurrent) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
if store, ok := c.checkpoints[slot]; ok {
store.shiftRange(beginIndex, endIndex)
}
}
func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.recurrent != nil {
if dstEntry.recurrent == nil {
dstEntry.recurrent = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.recurrent {
dst := c.ensureCheckpointRecurrent(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointRecurrent(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointRecurrent(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
if entry.recurrent == nil {
entry.recurrent = make(map[int]ml.Tensor)
}
if t, ok := entry.recurrent[layer]; ok {
return t
}
ctx, ok := c.checkpointRecurCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointRecurCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
entry.recurrent[layer] = t
return t
}
func (c *Recurrent) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointRecurrent(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}

View File

@@ -0,0 +1,288 @@
package kvcache
import (
"errors"
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
)
func newTestCache() *Recurrent {
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestCachePrepareRestore(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate layer 0 requires both conv and recurrent checkpoints.
cache.convStates[0] = nil
cache.recurrentStates[0] = nil
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
entry.conv = map[int]ml.Tensor{0: nil}
// entry.recurrent intentionally missing
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := newTestCache()
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
cache := newTestCache()
_, err := cache.RecurrentState(nil, 0, 3)
if !errors.Is(err, ErrInvalidRecurrentShape) {
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
}
}
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
store := newSlotCheckpointStore(5)
store.record(1)
store.record(4)
store.record(7)
store.record(10)
store.shiftRange(2, 6)
var positions []int32
for i := range store.entries {
if store.entries[i].pos >= 0 {
positions = append(positions, store.entries[i].pos)
}
}
slices.Sort(positions)
want := []int32{1, 3, 6}
if !slices.Equal(positions, want) {
t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
}
if store.lastPos != 6 {
t.Fatalf("expected lastPos 6, got %d", store.lastPos)
}
}
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
cache := newTestCache()
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}
store := cache.checkpointStore(0)
store.record(1)
store.record(4)
store.record(7)
store.record(10)
if err := cache.Remove(1, 2, 6); err != nil {
t.Fatalf("expected middle remove to succeed, got %v", err)
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared after middle remove")
}
var positions []int32
for i := range store.entries {
if store.entries[i].pos >= 0 {
positions = append(positions, store.entries[i].pos)
}
}
slices.Sort(positions)
want := []int32{1, 3, 6}
if !slices.Equal(positions, want) {
t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil
store.entries[0].recurrent = make(map[int]ml.Tensor)
store.entries[0].recurrent[0] = nil
store.record(40)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].recurrent == nil {
t.Fatalf("expected recurrent map to be preserved on reuse")
}
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -0,0 +1,37 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
specialization
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -163,6 +163,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
(ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {

View File

@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {

View File

@@ -1,410 +1,44 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
// HybridCache adapts the shared recurrent cache for LFM2:
// - KV attention cache is handled by the embedded causal cache
// - shortconv recurrent state uses conv slots [dConv, hiddenSize]
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
// This reuses shared checkpoint/restore logic for prefix mismatch recovery.
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
*kvcache.Recurrent
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: shift,
ConvDim: dConv,
ConvChannels: hiddenSize,
RecurrentStateSize: 1, // LFM2 uses only conv state; keep a minimal recurrent buffer size.
CheckpointLogPrefix: "lfm2",
})
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
// without modifying permanent state (slotForSeq, refCount)
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int // track newly allocated slots that need zeroing
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero conv state for newly allocated slots to clear stale data from previous sequences
if len(newSlots) > 0 {
c.zeroConvSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 || len(c.convStates) == 0 {
return
}
// Use input context for creating tensors
inputCtx := ctx.Input()
// Create slot indices tensor
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Create zero tensor for the slots (SetRows requires F32 source)
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
// Zero each layer's conv state for these slots
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
return &HybridCache{Recurrent: base}
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
return c.SlotsTensor()
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
return c.NumSeqs()
}

View File

@@ -4,441 +4,39 @@ import (
"testing"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
func TestHybridCache_New(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
if cache == nil {
t.Fatal("expected cache to be created")
}
// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
return &HybridCache{
hiddenSize: 256,
dConv: 3,
maxSequences: maxSequences,
refCount: make([]int, maxSequences),
freeSlots: initFreeSlots(maxSequences),
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
if cache.Recurrent == nil {
t.Fatal("expected embedded recurrent cache to be created")
}
}
func initFreeSlots(n int) []int {
slots := make([]int, 0, n)
for i := n - 1; i >= 0; i-- {
slots = append(slots, i)
}
return slots
}
func TestHybridCache_ImplementsCheckpointCache(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
func TestHybridCache_SlotAllocation(t *testing.T) {
cache := createSlotOnlyCache(4)
// Verify initial state
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
}
// Allocate all slots
for range 4 {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.refCount[slot] = 1
}
// Should be full now
if len(cache.freeSlots) != 0 {
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
}
// Trying to allocate another should fail
_, err := cache.allocSlot()
if err != kvcache.ErrKvCacheFull {
t.Errorf("expected ErrKvCacheFull, got %v", err)
if _, ok := any(cache).(kvcache.CheckpointCache); !ok {
t.Fatal("expected HybridCache to implement CheckpointCache")
}
}
func TestHybridCache_SlotReuse(t *testing.T) {
cache := createSlotOnlyCache(4)
func TestHybridCache_DefaultBatchState(t *testing.T) {
cache := NewHybridCache(nil, 512, 2)
// Allocate a slot
slot1, _ := cache.allocSlot()
cache.refCount[slot1] = 1
// Free it
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
// Allocate again - should get the same slot back (LIFO)
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
}
}
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Simulate sharing slot with seq 2 (copy-on-write style)
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Should share the same slot
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
if got := cache.numSeqs(); got != 0 {
t.Fatalf("expected 0 sequences before StartForward, got %d", got)
}
// Ref count should be 2
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share with seq 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Unshare seq 2
cache.refCount[slot1]--
delete(cache.slotForSeq, 2)
// Ref count should be back to 1
if cache.refCount[slot1] != 1 {
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
if got := cache.seqTokens(); got != 0 {
t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got)
}
// Seq 2 should no longer have a slot
if _, ok := cache.slotForSeq[2]; ok {
t.Error("seq 2 should not have a slot after unshare")
}
}
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
cache := createSlotOnlyCache(4)
initialFreeSlots := len(cache.freeSlots)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot when refCount drops to 0
cache.refCount[slot1]--
if cache.refCount[slot1] <= 0 {
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
}
delete(cache.slotForSeq, 1)
// Slot should be freed
if len(cache.freeSlots) != initialFreeSlots {
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
}
// Ref count should be 0
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotOverwrite(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slots for seq 1 and seq 2
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
slot2, _ := cache.allocSlot()
cache.slotForSeq[2] = slot2
cache.refCount[slot2] = 1
initialFreeSlots := len(cache.freeSlots)
// Simulate overwriting seq 2's slot with slot1 (sharing)
// First free the old slot
cache.refCount[slot2]--
if cache.refCount[slot2] <= 0 {
cache.refCount[slot2] = 0
cache.freeSlot(slot2)
}
// Then share slot1
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Seq 2 should now share slot1
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Old slot2 should be freed
if len(cache.freeSlots) != initialFreeSlots+1 {
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
}
}
func TestHybridCache_BoundsChecking(t *testing.T) {
cache := createSlotOnlyCache(4)
// Test freeing invalid slot (should not panic)
cache.freeSlot(-1)
cache.freeSlot(100) // out of bounds
// freeSlot does bounds checking, so invalid slots should be ignored
if len(cache.freeSlots) != 4 {
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
}
}
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
cache := createSlotOnlyCache(8)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Fork to seq 2, 3, 4 (all share slot1)
for _, seq := range []int{2, 3, 4} {
cache.slotForSeq[seq] = slot1
cache.refCount[slot1]++
}
// Ref count should be 4
if cache.refCount[slot1] != 4 {
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
}
// Remove seq 2, 3
for _, seq := range []int{2, 3} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
// Slot should still be allocated (not in free list)
found := false
for _, s := range cache.freeSlots {
if s == slot1 {
found = true
break
}
}
if found {
t.Error("slot1 should not be in free list yet")
}
// Remove remaining sequences
for _, seq := range []int{1, 4} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_ChainedSharing(t *testing.T) {
cache := createSlotOnlyCache(8)
// Create seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share 1 -> 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Share 2 -> 3 (should still share slot1)
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
cache.refCount[slot1]++
// All should share slot1
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
t.Error("all sequences should share slot1")
}
if cache.refCount[slot1] != 3 {
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_CacheParameters(t *testing.T) {
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
if cache.hiddenSize != 512 {
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
}
if cache.dConv != 5 {
t.Errorf("expected dConv 5, got %d", cache.dConv)
}
}
func TestHybridCache_NumSeqs(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially no sequences
if cache.numSeqs() != 0 {
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
}
// Manually set up current batch state
cache.curSeqs = []int{1, 2, 3}
if cache.numSeqs() != 3 {
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
}
}
func TestHybridCache_SeqTokens(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially 0
if cache.seqTokens() != 0 {
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
}
// Manually set up current batch state
cache.curSeqTokens = 16
if cache.seqTokens() != 16 {
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
}
}
// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
cache := createSlotOnlyCache(4)
cache.curSeqs = []int{1, 2, 3}
seqs := cache.Seqs()
// Modify returned slice
seqs[0] = 999
// Original should be unchanged
if cache.curSeqs[0] != 1 {
t.Error("Seqs should return a clone, not the original slice")
}
}
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially not supported (no batch set up)
if cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be false initially")
}
// Set up a valid batch
cache.curSeqTokens = 1
cache.curSeqs = []int{1}
if !cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be true with valid batch")
}
}
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
cache := createSlotOnlyCache(4)
// zeroConvSlots should handle empty slots without panicking
cache.zeroConvSlots(nil, nil)
cache.zeroConvSlots(nil, []int{})
// zeroConvSlots should handle empty convStates without panicking
cache.zeroConvSlots(nil, []int{0, 1, 2})
}
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot (simulating sequence removal)
cache.refCount[slot1]--
cache.freeSlot(slot1)
delete(cache.slotForSeq, 1)
// Verify slot is in free list
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
}
// Allocate for new seq 2 - should get recycled slot
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
}
// This recycled slot would need zeroing in the real implementation
// The actual zeroing is tested via integration tests since it requires ML context
}
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
cache := createSlotOnlyCache(4)
// Simulate the slot allocation flow from StartForward
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
newSlots := []int{}
// Seq 1 doesn't have a slot - allocate and track
seq := 1
if _, ok := cache.slotForSeq[seq]; !ok {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
// Verify newSlots contains the allocated slot
if len(newSlots) != 1 {
t.Errorf("expected 1 new slot, got %d", len(newSlots))
}
// Seq 1 already has a slot - should NOT be tracked as new
newSlots2 := []int{}
if _, ok := cache.slotForSeq[seq]; !ok {
slot, _ := cache.allocSlot()
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots2 = append(newSlots2, slot)
}
// Verify no new slots for existing sequence
if len(newSlots2) != 0 {
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
t.Fatal("expected unsupported batch layout before StartForward")
}
}

View File

@@ -1,7 +1,11 @@
package lfm2
import (
"bytes"
"cmp"
"errors"
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
@@ -25,8 +29,19 @@ type Options struct {
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
// MoE config
numExperts int
numExpertsUsed int
normTopKProb bool
expertGatingFunc uint32
}
const (
expertGatingFuncSoftmax = uint32(0)
expertGatingFuncSigmoid = uint32(2)
)
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
@@ -67,18 +82,138 @@ type Model struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
VisionModel *VisionModel `gguf:"v"`
VisionProjector *VisionProjector `gguf:"mm"`
ImageProcessor ImageProcessor
imageTokenID int32
imageStartToken int32
imageEndToken int32
imageThumbnailID int32
imageRowColIDs map[imageGridPos]int32
useSpecialTokens bool
projectorOptions VisionProjectorOptions
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
return nil, model.ErrUnsupportedModel
var _ model.MultimodalProcessor = (*Model)(nil)
type imageGridPos struct {
row int
col int
}
type visionEmbeddingLayout struct {
rows int
cols int
hasThumbnail bool
}
type visionChunkData struct {
tokens int
row int
col int
thumbnail bool
layout *visionEmbeddingLayout
}
func (m *Model) Validate() error {
if m.TokenEmbedding == nil {
return errors.New("lfm2: missing token_embd tensor")
}
if m.OutputNorm == nil {
return errors.New("lfm2: missing output_norm tensor")
}
if m.Output == nil {
return errors.New("lfm2: missing output tensor")
}
for i, layer := range m.Layers {
if layer.AttentionNorm == nil {
return fmt.Errorf("lfm2: missing blk.%d.attn_norm tensor", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("lfm2: missing blk.%d.ffn_norm tensor", i)
}
switch ff := layer.MLP.(type) {
case nil:
return fmt.Errorf("lfm2: missing blk.%d feed-forward tensors", i)
case *denseMLP:
if ff.Up == nil || ff.Down == nil || ff.Gate == nil {
return fmt.Errorf("lfm2: missing blk.%d dense feed-forward tensors", i)
}
case *sparseMLP:
if ff.Router == nil || ff.Gate == nil || ff.Up == nil || ff.Down == nil {
return fmt.Errorf("lfm2: missing blk.%d sparse feed-forward tensors", i)
}
default:
return fmt.Errorf("lfm2: unsupported feed-forward type at blk.%d", i)
}
switch op := layer.Operator.(type) {
case *Attention:
if op == nil || op.Query == nil || op.Key == nil || op.Value == nil || op.Output == nil || op.QueryNorm == nil || op.KeyNorm == nil {
return fmt.Errorf("lfm2: missing blk.%d attention tensors", i)
}
case *ShortConv:
if op == nil || op.Conv == nil || op.Conv.Weight == nil || op.InProj == nil || op.OutProj == nil {
return fmt.Errorf("lfm2: missing blk.%d shortconv tensors", i)
}
default:
return fmt.Errorf("lfm2: unsupported operator at blk.%d", i)
}
}
if m.VisionModel != nil {
if m.VisionModel.PatchEmbedding == nil {
return errors.New("lfm2: missing vision patch embedding tensors")
}
if m.VisionModel.PositionEmbedding == nil {
return errors.New("lfm2: missing vision position embedding tensors")
}
if m.VisionModel.PostLayerNorm == nil {
return errors.New("lfm2: missing vision post layer norm tensors")
}
if len(m.VisionModel.Layers) == 0 {
return errors.New("lfm2: missing vision encoder layers")
}
for i, layer := range m.VisionModel.Layers {
if layer.LayerNorm1 == nil || layer.LayerNorm2 == nil || layer.SelfAttention == nil || layer.MLP == nil {
return fmt.Errorf("lfm2: missing vision layer tensors at v.blk.%d", i)
}
if layer.SelfAttention.Query == nil || layer.SelfAttention.Key == nil || layer.SelfAttention.Value == nil || layer.SelfAttention.Output == nil {
return fmt.Errorf("lfm2: missing vision attention tensors at v.blk.%d", i)
}
if layer.MLP.Up == nil || layer.MLP.Down == nil {
return fmt.Errorf("lfm2: missing vision feed-forward tensors at v.blk.%d", i)
}
}
if m.VisionProjector == nil || m.VisionProjector.Linear1 == nil || m.VisionProjector.Linear2 == nil {
return errors.New("lfm2: missing multimodal projector tensors")
}
}
return nil
}
func New(c fs.Config) (model.Model, error) {
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
numExperts := int(c.Uint("expert_count"))
isMoE := numExperts > 0
numExpertsUsed := int(c.Uint("expert_used_count"))
if isMoE {
if numExperts <= 0 {
return nil, fmt.Errorf("lfm2: invalid expert_count=%d", numExperts)
}
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("lfm2: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -105,8 +240,16 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
VisionProjector: &VisionProjector{},
imageRowColIDs: make(map[imageGridPos]int32),
projectorOptions: VisionProjectorOptions{
scaleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
useLayerNorm: c.Bool("vision.projector.use_layernorm", false),
},
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
@@ -116,9 +259,65 @@ func New(c fs.Config) (model.Model, error) {
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
normTopKProb: c.Bool("norm_top_k_prob", true),
expertGatingFunc: c.Uint("expert_gating_func", expertGatingFuncSoftmax),
},
}
lookupTokenID := func(token string) int32 {
for i, t := range vocabulary.Values {
if t == token {
return int32(i)
}
}
return 0
}
resolveTokenID := func(explicitKey, token string, fallback uint32) int32 {
if explicitKey != "" {
if id := c.Uint(explicitKey); id != 0 {
return int32(id)
}
}
if tokenID := lookupTokenID(token); tokenID != 0 {
return tokenID
}
return int32(fallback)
}
m.imageTokenID = resolveTokenID("vision.image_token_id", "<image>", 396)
m.imageStartToken = resolveTokenID("vision.image_start_token_id", "<|image_start|>", 0)
m.imageEndToken = resolveTokenID("vision.image_end_token_id", "<|image_end|>", 0)
m.imageThumbnailID = resolveTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>", 0)
m.useSpecialTokens = c.Bool("vision.use_image_special_tokens", true)
maxGridTokens := int(c.Uint("vision.max_tiles", 10))
if maxGridTokens <= 0 {
maxGridTokens = 10
}
for row := 1; row <= maxGridTokens; row++ {
for col := 1; col <= maxGridTokens; col++ {
token := fmt.Sprintf("<|img_row_%d_col_%d|>", row, col)
if tokenID := lookupTokenID(token); tokenID > 0 {
m.imageRowColIDs[imageGridPos{row: row, col: col}] = tokenID
}
}
}
if !m.useSpecialTokens {
m.imageStartToken = 0
m.imageEndToken = 0
m.imageThumbnailID = 0
m.imageRowColIDs = map[imageGridPos]int32{}
}
if c.Uint("vision.block_count") == 0 {
m.VisionModel = nil
m.VisionProjector = nil
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
@@ -133,6 +332,14 @@ func New(c fs.Config) (model.Model, error) {
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count"))
if leadingDenseBlockCount < 0 {
leadingDenseBlockCount = 0
}
if leadingDenseBlockCount > len(m.Layers) {
leadingDenseBlockCount = len(m.Layers)
}
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
@@ -142,6 +349,12 @@ func New(c fs.Config) (model.Model, error) {
} else {
m.Layers[i].Operator = &Attention{}
}
if isMoE && i >= leadingDenseBlockCount {
m.Layers[i].MLP = &sparseMLP{}
} else {
m.Layers[i].MLP = &denseMLP{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
@@ -188,22 +401,74 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
type FeedForward interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type denseMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
func (mlp *denseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type sparseMLP struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (mlp *sparseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
// hiddenState: [hidden, tokens]
routerLogits := mlp.Router.Forward(ctx, hiddenState)
probs := routerLogits.Softmax(ctx)
if opts.expertGatingFunc == expertGatingFuncSigmoid {
probs = routerLogits.Sigmoid(ctx)
}
selectionProbs := probs
if mlp.Bias != nil {
selectionProbs = selectionProbs.Add(ctx, mlp.Bias)
}
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
weightsSum := routingWeights.SumRows(ctx)
weightsSum = weightsSum.Clamp(ctx, 1e-6, float32(math.Inf(1)))
routingWeights = routingWeights.Div(ctx, weightsSum)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
}
// Build routing-weights branch early to enable topk-MoE fusion.
ctx.Forward(routingWeights)
hiddenState3D := hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
experts := mlp.Gate.Forward(ctx, hiddenState3D, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenState3D, selectedExperts))
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextState := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextState = nextState.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextState
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
MLP FeedForward
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
@@ -229,10 +494,233 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func multimodalTokenCount(mm input.Multimodal) int {
if mm.Tensor != nil {
return mm.Tensor.Dim(1)
}
switch data := mm.Data.(type) {
case int:
return data
case int32:
return int(data)
case visionChunkData:
return data.tokens
case *visionChunkData:
if data != nil {
return data.tokens
}
}
return 0
}
func multimodalChunkInfo(mm input.Multimodal) visionChunkData {
switch data := mm.Data.(type) {
case visionChunkData:
return data
case *visionChunkData:
if data != nil {
return *data
}
}
return visionChunkData{
tokens: multimodalTokenCount(mm),
}
}
func multimodalLayout(mm []input.Multimodal) visionEmbeddingLayout {
layout := visionEmbeddingLayout{rows: 1, cols: 1}
if len(mm) == 0 {
return layout
}
first := multimodalChunkInfo(mm[0])
if first.layout != nil {
return *first.layout
}
return layout
}
func (m *Model) imageRowColToken(row, col int) int32 {
if row <= 0 || col <= 0 {
return 0
}
return m.imageRowColIDs[imageGridPos{row: row, col: col}]
}
func (m *Model) appendImageChunk(result []*input.Input, chunk input.Multimodal, imageToken int32, hash uint64) ([]*input.Input, error) {
tokenCount := multimodalTokenCount(chunk)
if tokenCount <= 0 {
return nil, errors.New("lfm2: multimodal input has no tokens")
}
result = append(result, &input.Input{
Token: imageToken,
Multimodal: []input.Multimodal{chunk},
MultimodalHash: hash,
SameBatch: tokenCount - 1,
})
for range tokenCount - 1 {
result = append(result, &input.Input{Token: imageToken})
}
return result, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if m.VisionModel == nil || m.VisionProjector == nil || len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
processedImages, layout, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
if m.ImageProcessor.patchSize <= 0 {
return nil, errors.New("lfm2: invalid vision patch size")
}
layoutInfo := &visionEmbeddingLayout{
rows: layout.rows,
cols: layout.cols,
hasThumbnail: layout.hasThumbnail,
}
mm := make([]input.Multimodal, 0, len(processedImages))
for i, processed := range processedImages {
patches := visionPatchGrid{
Width: processed.size.X / m.ImageProcessor.patchSize,
Height: processed.size.Y / m.ImageProcessor.patchSize,
}
if patches.Width == 0 || patches.Height == 0 {
return nil, errors.New("lfm2: invalid resized image dimensions")
}
pixelValues := ctx.Input().FromFloats(processed.data, processed.size.X, processed.size.Y, m.ImageProcessor.numChannels)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, patches)
projected := m.VisionProjector.Forward(ctx, visionOutputs, patches, m.projectorOptions)
chunk := visionChunkData{
tokens: projected.Dim(1),
row: processed.row,
col: processed.col,
thumbnail: processed.thumbnail,
}
if i == 0 {
chunk.layout = layoutInfo
}
mm = append(mm, input.Multimodal{
Tensor: projected,
Data: chunk,
})
}
return mm, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
imageToken := m.imageTokenID
if imageToken == 0 {
imageToken = 396
}
useSpecialTokens := m.useSpecialTokens || m.imageStartToken > 0 || m.imageEndToken > 0 || m.imageThumbnailID > 0 || len(m.imageRowColIDs) > 0
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
layout := multimodalLayout(inp.Multimodal)
if layout.rows <= 0 {
layout.rows = 1
}
if layout.cols <= 0 {
layout.cols = 1
}
tiles := layout.rows * layout.cols
multitile := tiles > 1
if useSpecialTokens && m.imageStartToken > 0 {
result = append(result, &input.Input{Token: m.imageStartToken})
}
for i, mm := range inp.Multimodal {
chunk := multimodalChunkInfo(mm)
if chunk.tokens <= 0 {
chunk.tokens = multimodalTokenCount(mm)
}
if multitile && !chunk.thumbnail && chunk.row == 0 && chunk.col == 0 && i < tiles {
chunk.row = i/layout.cols + 1
chunk.col = i%layout.cols + 1
}
if multitile && layout.hasThumbnail && i == tiles {
chunk.thumbnail = true
}
if useSpecialTokens && multitile {
if chunk.thumbnail {
if m.imageThumbnailID > 0 {
result = append(result, &input.Input{Token: m.imageThumbnailID})
}
} else if marker := m.imageRowColToken(chunk.row, chunk.col); marker > 0 {
result = append(result, &input.Input{Token: marker})
}
}
var err error
result, err = m.appendImageChunk(result, input.Multimodal{
Tensor: mm.Tensor,
Data: chunk,
}, imageToken, inp.MultimodalHash)
if err != nil {
return nil, err
}
}
if useSpecialTokens && m.imageEndToken > 0 {
result = append(result, &input.Input{Token: m.imageEndToken})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
if len(batch.Multimodal) > 0 {
// We splice vision embeddings into token embeddings in-place; duplicate to
// avoid aliasing the raw embedding output graph.
hiddenState = hiddenState.Duplicate(ctx)
}
for _, mm := range batch.Multimodal {
offset := mm.Index
for _, multimodal := range mm.Multimodal {
if multimodal.Tensor == nil {
continue
}
visionOutputs := multimodal.Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, offset*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
offset += visionOutputs.Dim(1)
}
}
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
@@ -251,4 +739,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func init() {
model.Register("lfm2", New)
model.Register("lfm2moe", New)
}

View File

@@ -0,0 +1,160 @@
package lfm2
import (
"testing"
"github.com/ollama/ollama/model/input"
)
func TestPostTokenizeWithSpecialImageTokens(t *testing.T) {
m := &Model{
imageTokenID: 396,
imageStartToken: 2,
imageEndToken: 3,
useSpecialTokens: true,
}
in := []*input.Input{
{Token: 11},
{Multimodal: []input.Multimodal{{Data: 64}}, MultimodalHash: 123},
{Token: 12},
}
out, err := m.PostTokenize(in)
if err != nil {
t.Fatalf("PostTokenize returned error: %v", err)
}
if len(out) != 68 {
t.Fatalf("expected 68 tokens, got %d", len(out))
}
if out[0].Token != 11 {
t.Fatalf("out[0].Token = %d, want 11", out[0].Token)
}
if out[1].Token != 2 {
t.Fatalf("out[1].Token = %d, want 2", out[1].Token)
}
firstImage := out[2]
if firstImage.Token != 396 {
t.Fatalf("out[2].Token = %d, want 396", firstImage.Token)
}
if len(firstImage.Multimodal) != 1 {
t.Fatalf("expected multimodal payload on first image token")
}
if firstImage.MultimodalHash != 123 {
t.Fatalf("out[2].MultimodalHash = %d, want 123", firstImage.MultimodalHash)
}
if firstImage.SameBatch != 63 {
t.Fatalf("out[2].SameBatch = %d, want 63", firstImage.SameBatch)
}
for i := 3; i < 66; i++ {
if out[i].Token != 396 {
t.Fatalf("out[%d].Token = %d, want 396", i, out[i].Token)
}
if len(out[i].Multimodal) != 0 {
t.Fatalf("out[%d] should not carry multimodal payload", i)
}
}
if out[66].Token != 3 {
t.Fatalf("out[66].Token = %d, want 3", out[66].Token)
}
if out[67].Token != 12 {
t.Fatalf("out[67].Token = %d, want 12", out[67].Token)
}
}
func TestPostTokenizeWithoutSpecialImageTokens(t *testing.T) {
m := &Model{
imageTokenID: 777,
useSpecialTokens: false,
}
in := []*input.Input{
{Multimodal: []input.Multimodal{{Data: 5}}, MultimodalHash: 9},
}
out, err := m.PostTokenize(in)
if err != nil {
t.Fatalf("PostTokenize returned error: %v", err)
}
if len(out) != 5 {
t.Fatalf("expected 5 tokens, got %d", len(out))
}
if out[0].Token != 777 || out[0].SameBatch != 4 || len(out[0].Multimodal) != 1 {
t.Fatalf("unexpected first token: %+v", *out[0])
}
for i := 1; i < 5; i++ {
if out[i].Token != 777 {
t.Fatalf("out[%d].Token = %d, want 777", i, out[i].Token)
}
if len(out[i].Multimodal) != 0 {
t.Fatalf("out[%d] should not carry multimodal payload", i)
}
}
}
func TestPostTokenizeMultiTileLayoutTokens(t *testing.T) {
m := &Model{
imageTokenID: 396,
imageStartToken: 498,
imageEndToken: 499,
imageThumbnailID: 497,
imageRowColIDs: map[imageGridPos]int32{
{row: 1, col: 1}: 397,
{row: 1, col: 2}: 398,
},
useSpecialTokens: true,
}
layout := &visionEmbeddingLayout{rows: 1, cols: 2, hasThumbnail: true}
in := []*input.Input{{
Multimodal: []input.Multimodal{
{Data: visionChunkData{tokens: 3, row: 1, col: 1, layout: layout}},
{Data: visionChunkData{tokens: 3, row: 1, col: 2}},
{Data: visionChunkData{tokens: 2, thumbnail: true}},
},
MultimodalHash: 1,
}}
out, err := m.PostTokenize(in)
if err != nil {
t.Fatalf("PostTokenize returned error: %v", err)
}
got := make([]int32, len(out))
for i := range out {
got[i] = out[i].Token
}
want := []int32{
498, // <|image_start|>
397, // <|img_row_1_col_1|>
396, 396, 396,
398, // <|img_row_1_col_2|>
396, 396, 396,
497, // <|img_thumbnail|>
396, 396,
499, // <|image_end|>
}
if len(got) != len(want) {
t.Fatalf("len(out) = %d, want %d", len(got), len(want))
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("out[%d].Token = %d, want %d", i, got[i], want[i])
}
}
if len(out[2].Multimodal) != 1 || len(out[6].Multimodal) != 1 || len(out[10].Multimodal) != 1 {
t.Fatalf("expected multimodal payload on first token of each chunk")
}
if out[2].SameBatch != 2 || out[6].SameBatch != 2 || out[10].SameBatch != 1 {
t.Fatalf("unexpected SameBatch values: [%d %d %d]", out[2].SameBatch, out[6].SameBatch, out[10].SameBatch)
}
}

View File

@@ -0,0 +1,184 @@
package lfm2
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
const lfm2VisionBatchSize = 1
type visionPatchGrid struct {
Width int
Height int
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), lfm2VisionBatchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), lfm2VisionBatchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), lfm2VisionBatchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), lfm2VisionBatchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
numPatches := patches.Width * patches.Height
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
if m.PositionEmbedding != nil {
posTokens := m.PositionEmbedding.Weight.Dim(1)
source := int(math.Sqrt(float64(posTokens)))
var positionEmbeddings ml.Tensor
if source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height) {
// SigLIP2 NAFlex-style position interpolation for variable image sizes.
positionIDs := ctx.Arange(0, float32(posTokens), 1, ml.DTypeI32)
positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
patches.Width,
patches.Height,
hiddenState.Dim(0),
1,
}, ml.SamplingModeBilinear)
positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
positionEmbeddings = positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
} else {
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
}
hiddenState = hiddenState.Add(ctx, positionEmbeddings)
}
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
return m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1152)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
imageSize: int(c.Uint("vision.image_size", 256)),
patchSize: int(c.Uint("vision.patch_size", 16)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
},
}
}
type VisionProjector struct {
LayerNorm *nn.LayerNorm `gguf:"layer_norm"`
Linear1 *nn.Linear `gguf:"1"`
Linear2 *nn.Linear `gguf:"2"`
}
type VisionProjectorOptions struct {
scaleFactor int
useLayerNorm bool
}
func (p *VisionProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, opts VisionProjectorOptions) ml.Tensor {
hiddenSize := visionOutputs.Dim(0)
featureMap := visionOutputs
merge := max(opts.scaleFactor, 1)
if merge > 1 {
width := patches.Width
height := patches.Height
featureMap = featureMap.Reshape(ctx, hiddenSize, width, height)
// Match llama.cpp patch merger: pad spatial dims to merge factor.
padWidth := (merge - width%merge) % merge
padHeight := (merge - height%merge) % merge
if padWidth != 0 || padHeight != 0 {
featureMap = featureMap.Pad(ctx, 0, padWidth, padHeight, 0)
width += padWidth
height += padHeight
}
featureMap = featureMap.Reshape(ctx, hiddenSize*merge, width/merge, height)
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx, hiddenSize*merge*merge, height/merge, width/merge)
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx)
featureMap = featureMap.Contiguous(ctx, featureMap.Dim(0), featureMap.Dim(1)*featureMap.Dim(2))
}
if opts.useLayerNorm && p.LayerNorm != nil {
featureMap = p.LayerNorm.Forward(ctx, featureMap, 1e-5)
}
featureMap = p.Linear1.Forward(ctx, featureMap).GELU(ctx)
return p.Linear2.Forward(ctx, featureMap)
}

View File

@@ -0,0 +1,260 @@
package lfm2
import (
"image"
stdimage "image/draw"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
downsampleFactor int
imageMean, imageStd [3]float32
doImageSplitting bool
minTiles int
maxTiles int
useThumbnail bool
tileSize int
minImageTokens int
maxImageTokens int
maxPixelsTolerance float64
}
type processedVisionImage struct {
data []float32
size image.Point
row int
col int
thumbnail bool
}
type processedVisionLayout struct {
rows int
cols int
hasThumbnail bool
}
func newImageProcessor(c fs.Config) ImageProcessor {
mean := c.Floats("vision.image_mean")
std := c.Floats("vision.image_std")
processor := ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 256)),
patchSize: int(c.Uint("vision.patch_size", 16)),
numChannels: int(c.Uint("vision.num_channels", 3)),
downsampleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
imageMean: [3]float32{0.5, 0.5, 0.5},
imageStd: [3]float32{0.5, 0.5, 0.5},
doImageSplitting: c.Bool("vision.do_image_splitting", true),
minTiles: int(c.Uint("vision.min_tiles", 2)),
maxTiles: int(c.Uint("vision.max_tiles", 10)),
useThumbnail: c.Bool("vision.use_thumbnail", true),
tileSize: int(c.Uint("vision.tile_size", 512)),
minImageTokens: int(c.Uint("vision.min_image_tokens", 64)),
maxImageTokens: int(c.Uint("vision.max_image_tokens", 256)),
maxPixelsTolerance: float64(c.Float("vision.max_pixels_tolerance", 2.0)),
}
if len(mean) >= 3 {
processor.imageMean = [3]float32{mean[0], mean[1], mean[2]}
}
if len(std) >= 3 {
processor.imageStd = [3]float32{std[0], std[1], std[2]}
}
// Keep defaults aligned with HF unless explicitly configured.
if processor.downsampleFactor <= 0 {
processor.downsampleFactor = 2
}
if processor.patchSize <= 0 {
processor.patchSize = 16
}
if processor.tileSize <= 0 {
processor.tileSize = 512
}
if processor.minTiles <= 0 {
processor.minTiles = 2
}
if processor.maxTiles < processor.minTiles {
processor.maxTiles = processor.minTiles
}
if processor.minImageTokens <= 0 {
processor.minImageTokens = 64
}
if processor.maxImageTokens < processor.minImageTokens {
processor.maxImageTokens = processor.minImageTokens
}
if processor.maxPixelsTolerance <= 0 {
processor.maxPixelsTolerance = 2.0
}
return processor
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionImage, processedVisionLayout, error) {
img = imageproc.Composite(img)
orig := img.Bounds().Size()
resizedWidth, resizedHeight := p.smartResize(orig.Y, orig.X)
layout := processedVisionLayout{rows: 1, cols: 1}
if p.shouldSplit(orig.Y, orig.X) {
gridWidth, gridHeight, targetWidth, targetHeight := p.gridLayout(orig.Y, orig.X)
layout.rows = gridHeight
layout.cols = gridWidth
layout.hasThumbnail = p.useThumbnail && gridWidth*gridHeight != 1
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
images := make([]processedVisionImage, 0, gridWidth*gridHeight+1)
for row := range gridHeight {
for col := range gridWidth {
rect := image.Rect(
col*p.tileSize,
row*p.tileSize,
(col+1)*p.tileSize,
(row+1)*p.tileSize,
)
tile := cropImage(resized, rect)
images = append(images, processedVisionImage{
data: imageproc.Normalize(tile, p.imageMean, p.imageStd, true, true),
size: tile.Bounds().Size(),
row: row + 1,
col: col + 1,
})
}
}
if layout.hasThumbnail {
thumbnail := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
images = append(images, processedVisionImage{
data: imageproc.Normalize(thumbnail, p.imageMean, p.imageStd, true, true),
size: thumbnail.Bounds().Size(),
thumbnail: true,
})
}
return images, layout, nil
}
single := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
return []processedVisionImage{{
data: imageproc.Normalize(single, p.imageMean, p.imageStd, true, true),
size: single.Bounds().Size(),
}}, layout, nil
}
func (p ImageProcessor) shouldSplit(height, width int) bool {
if !p.doImageSplitting || p.minTiles == 1 && p.maxTiles == 1 {
return false
}
totalFactor := p.patchSize * p.downsampleFactor
hBar := max(p.patchSize, roundByFactor(height, totalFactor))
wBar := max(p.patchSize, roundByFactor(width, totalFactor))
limit := float64(p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor)
limit *= p.maxPixelsTolerance
return float64(hBar*wBar) > limit
}
func (p ImageProcessor) smartResize(height, width int) (int, int) {
totalFactor := p.patchSize * p.downsampleFactor
minPixels := p.minImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
maxPixels := p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
hBar := max(totalFactor, roundByFactor(height, totalFactor))
wBar := max(totalFactor, roundByFactor(width, totalFactor))
if hBar*wBar > maxPixels {
beta := math.Sqrt(float64(height*width) / float64(maxPixels))
hBar = max(totalFactor, int(math.Floor(float64(height)/beta/float64(totalFactor)))*totalFactor)
wBar = max(totalFactor, int(math.Floor(float64(width)/beta/float64(totalFactor)))*totalFactor)
} else if hBar*wBar < minPixels {
beta := math.Sqrt(float64(minPixels) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(totalFactor))) * totalFactor
wBar = int(math.Ceil(float64(width)*beta/float64(totalFactor))) * totalFactor
}
return wBar, hBar
}
func (p ImageProcessor) gridLayout(height, width int) (gridWidth, gridHeight, targetWidth, targetHeight int) {
aspectRatio := float64(width) / float64(height)
targetRatios := p.targetRatios()
bestRatio := clipImageSize{width: 1, height: 1}
bestRatioDiff := math.MaxFloat64
area := float64(width * height)
for _, ratio := range targetRatios {
targetAspect := float64(ratio.width) / float64(ratio.height)
ratioDiff := math.Abs(aspectRatio - targetAspect)
if ratioDiff < bestRatioDiff {
bestRatioDiff = ratioDiff
bestRatio = ratio
continue
}
if ratioDiff == bestRatioDiff {
targetArea := float64(p.tileSize * p.tileSize * ratio.width * ratio.height)
if area > 0.5*targetArea {
bestRatio = ratio
}
}
}
return bestRatio.width, bestRatio.height, p.tileSize * bestRatio.width, p.tileSize * bestRatio.height
}
type clipImageSize struct {
width int
height int
}
func (p ImageProcessor) targetRatios() []clipImageSize {
targetRatios := make([]clipImageSize, 0, p.maxTiles*p.maxTiles)
for n := p.minTiles; n <= p.maxTiles; n++ {
for w := 1; w <= n; w++ {
for h := 1; h <= n; h++ {
if w*h < p.minTiles || w*h > p.maxTiles {
continue
}
targetRatios = append(targetRatios, clipImageSize{width: w, height: h})
}
}
}
unique := targetRatios[:0]
for _, ratio := range targetRatios {
if slices.Contains(unique, ratio) {
continue
}
unique = append(unique, ratio)
}
slices.SortFunc(unique, func(a, b clipImageSize) int {
return a.width*a.height - b.width*b.height
})
return unique
}
func roundByFactor(number, factor int) int {
if factor <= 0 {
return number
}
return int(math.RoundToEven(float64(number)/float64(factor))) * factor
}
func cropImage(img image.Image, rect image.Rectangle) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, rect.Dx(), rect.Dy()))
stdimage.Draw(dst, dst.Bounds(), img, rect.Min, stdimage.Src)
return dst
}

View File

@@ -0,0 +1,105 @@
package lfm2
import (
"image"
"image/color"
"testing"
)
func TestProcessImageSingleTile(t *testing.T) {
p := ImageProcessor{
patchSize: 16,
downsampleFactor: 2,
numChannels: 3,
imageMean: [3]float32{0.5, 0.5, 0.5},
imageStd: [3]float32{0.5, 0.5, 0.5},
doImageSplitting: true,
minTiles: 2,
maxTiles: 10,
useThumbnail: true,
tileSize: 512,
minImageTokens: 64,
maxImageTokens: 256,
maxPixelsTolerance: 2.0,
}
img := image.NewRGBA(image.Rect(0, 0, 320, 320))
out, layout, err := p.ProcessImage(img)
if err != nil {
t.Fatalf("ProcessImage returned error: %v", err)
}
if layout.rows != 1 || layout.cols != 1 || layout.hasThumbnail {
t.Fatalf("layout = %+v, want rows=1 cols=1 hasThumbnail=false", layout)
}
if len(out) != 1 {
t.Fatalf("len(out) = %d, want 1", len(out))
}
if out[0].size != (image.Point{X: 320, Y: 320}) {
t.Fatalf("single image size = %+v, want 320x320", out[0].size)
}
if out[0].thumbnail {
t.Fatalf("single image should not be marked as thumbnail")
}
}
func TestProcessImageDynamicTiling(t *testing.T) {
p := ImageProcessor{
patchSize: 16,
downsampleFactor: 2,
numChannels: 3,
imageMean: [3]float32{0.5, 0.5, 0.5},
imageStd: [3]float32{0.5, 0.5, 0.5},
doImageSplitting: true,
minTiles: 2,
maxTiles: 10,
useThumbnail: true,
tileSize: 512,
minImageTokens: 64,
maxImageTokens: 256,
maxPixelsTolerance: 2.0,
}
// Wide image that should trigger multi-tile splitting.
img := image.NewRGBA(image.Rect(0, 0, 3000, 1000))
fill := color.RGBA{R: 120, G: 90, B: 60, A: 255}
for y := range 1000 {
for x := range 3000 {
img.Set(x, y, fill)
}
}
out, layout, err := p.ProcessImage(img)
if err != nil {
t.Fatalf("ProcessImage returned error: %v", err)
}
if layout.rows*layout.cols <= 1 {
t.Fatalf("expected multi-tile layout, got %+v", layout)
}
if !layout.hasThumbnail {
t.Fatalf("expected thumbnail for multi-tile layout")
}
wantLen := layout.rows*layout.cols + 1
if len(out) != wantLen {
t.Fatalf("len(out) = %d, want %d", len(out), wantLen)
}
for i := range layout.rows * layout.cols {
if out[i].size != (image.Point{X: 512, Y: 512}) {
t.Fatalf("tile[%d] size = %+v, want 512x512", i, out[i].size)
}
if out[i].thumbnail {
t.Fatalf("tile[%d] should not be marked as thumbnail", i)
}
}
thumb := out[len(out)-1]
if !thumb.thumbnail {
t.Fatalf("last chunk should be thumbnail")
}
if thumb.size.X%32 != 0 || thumb.size.Y%32 != 0 {
t.Fatalf("thumbnail size = %+v, want dimensions aligned to 32", thumb.size)
}
}

View File

@@ -15,6 +15,7 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nemotronh"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo3"
_ "github.com/ollama/ollama/model/models/qwen2"

View File

@@ -0,0 +1,88 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// Attention implements simple attention without RoPE for Nemotron-H.
// Unlike Qwen3Next, Nemotron-H attention has:
// - No RoPE (position info comes from Mamba2 layers)
// - Standard scaled dot-product attention
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
batchSize := nSeqTokens
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
headDim := opts.getHeadDim()
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
}
// Q projection
query := a.Query.Forward(ctx, hiddenStates)
if query.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
}
numHeads := query.Dim(0) / headDim
query = query.Reshape(ctx, headDim, numHeads, batchSize)
// K projection
key := a.Key.Forward(ctx, hiddenStates)
if key.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
}
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
// V projection
value := a.Value.Forward(ctx, hiddenStates)
if value.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
}
if value.Dim(0)/headDim != numKVHeads {
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
}
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Standard attention computation (no RoPE)
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Output projection
return a.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,55 @@
package nemotronh
import (
"errors"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the layer requirements.
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
type HybridCache struct {
*kvcache.Recurrent
}
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: Shift,
ConvDim: convDim,
ConvChannels: convChannels,
RecurrentStateSize: ssmStateSize,
CheckpointLogPrefix: "nemotronh",
})
return &HybridCache{Recurrent: base}
}
// SSMState returns the SSM state for current batch sequences.
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
}
// UpdateSSMState writes a new SSM state for current batch sequences.
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
c.UpdateRecurrentState(ctx, layer, newState)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.SlotsTensor()
}
func (c *HybridCache) seqTokens() int {
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return c.NumSeqs()
}

View File

@@ -0,0 +1,197 @@
package nemotronh
import (
"log/slog"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
// The forward pass follows llama.cpp's build_mamba2_layer:
// 1. Input projection: zxBCdt = SSMIn @ hidden
// 2. Split: z, xBC, dt
// 3. Concat with conv state, apply SSMConv, save new conv state
// 4. Apply SiLU to convolved xBC
// 5. Split: x, B, C
// 6. Add dt bias
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
// 8. D skip: y = y + x * D
// 9. Swiglu with z: y = z * silu(y)
// 10. Group RMSNorm
// 11. Output projection
type Mamba2 struct {
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"`
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
Layer int
}
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := m.Layer
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
nSeqs := 1
dConv := opts.ssmDConv
dInner := opts.ssmDInner
dState := opts.ssmDState
nHead := opts.ssmNHead
headDim := dInner / nHead
nGroup := opts.ssmNGroup
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
// Split into z, xBC, dt
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
xBCSize := dInner + 2*nGroup*dState
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
if nSeqTokens == 1 {
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
}
// dt: [n_head, n_seq_tokens, n_seqs]
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
if nSeqTokens == 1 {
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
} else {
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
}
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
}
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
var xBCT ml.Tensor
if nSeqTokens == 1 {
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
} else {
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
}
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
convInput := convStates.Concat(ctx, xBCT, 0)
// Save new conv state (last d_conv-1 columns)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
// Add conv bias
if m.SSMConv1DB != nil {
xBC = xBC.Add(ctx, m.SSMConv1DB)
}
// Apply SiLU
xBC = xBC.SILU(ctx)
// Split xBC into x, B, C
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
x := xBC.Slice(ctx, 0, 0, dInner, 1)
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// B: [d_state, n_group, n_seq_tokens, n_seqs]
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// C: [d_state, n_group, n_seq_tokens, n_seqs]
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// Add dt bias
dt = dt.Add(ctx, m.SSMDtB)
// Get SSM state from cache
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
if err != nil {
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
}
// SSMScan
// state: [d_state, head_dim, n_head, n_seqs]
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
yElems := headDim * nHead * nSeqTokens * nSeqs
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
stateOffsetBytes := yElems * x.Stride(0)
stateElems := dState * headDim * nHead * nSeqs
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
// Update SSM state in cache
cache.UpdateSSMState(ctx, layer, newState)
// D skip connection: y = y + x * D
if m.SSMD != nil {
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
xD := x.Mul(ctx, m.SSMD)
y = y.Add(ctx, xD)
}
// Swiglu with z: y = z * silu(y)
y = z.SILU(ctx, y)
// Group RMSNorm
if m.SSMNorm != nil {
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
innerPerGroup := dInner / nGroup
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
y = m.SSMNorm.Forward(ctx, y, opts.eps)
}
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
// Output projection
out := m.SSMOut.Forward(ctx, y)
// Reshape to 2D for consistency with attention output
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}

View File

@@ -0,0 +1,417 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int // attention heads
numKVHeads int // KV heads for attention layers
headDim int
eps float32
// Mamba2 SSM config
ssmDConv int // conv kernel size
ssmDInner int // inner dimension (d_inner)
ssmDState int // state dimension
ssmNHead int // number of SSM heads (dt_rank)
ssmNGroup int // number of groups for B, C
// Per-layer configuration
isRecurrent []bool // true = Mamba2, false = attention or FFN
nFF []int // n_ff per layer (0 = attention-only)
// Attention scale
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
expertWeightsNorm bool
expertWeightsScale float32
expertWeightsNormClip float32
}
func (o Options) getHeadDim() int {
if o.headDim > 0 {
return o.headDim
}
if o.numHeads <= 0 {
return 0
}
return o.hiddenSize / o.numHeads
}
// Operator is the interface for layer operators (Mamba2 or Attention)
type Operator interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
MLP MLP // Dense or MoE FFN, or nil
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-layer norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Layer operator (Mamba2, Attention, or FFN)
if l.Operator != nil {
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
if err != nil {
return nil, err
}
} else if l.MLP != nil {
// FFN-only layer
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// Residual connection
return hiddenStates.Add(ctx, residual), nil
}
// Model is the main Nemotron-H model
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
// Shift is used for KV cache position shifting.
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer configuration from GGUF metadata
// Use the same interface pattern as qwen3next
type perLayerConfig interface {
HeadCount() []uint64
HeadCountKV() []uint64
FFNLength() []uint64
}
var headCount []uint64
var headCountKV []uint64
var ffnLength []uint64
if plc, ok := c.(perLayerConfig); ok {
headCount = plc.HeadCount()
headCountKV = plc.HeadCountKV()
ffnLength = plc.FFNLength()
}
// Build per-layer arrays with defaults
isRecurrent := make([]bool, numLayers)
nFF := make([]int, numLayers)
for i := range numLayers {
// Get per-layer values
kvHeads := uint64(1) // Default non-zero
if i < len(headCountKV) {
kvHeads = headCountKV[i]
}
ff := uint64(0)
if i < len(ffnLength) {
ff = ffnLength[i]
}
nFF[i] = int(ff)
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
// This matches llama.cpp behavior for Nemotron-H
isRecurrent[i] = kvHeads == 0 && ff == 0
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
// Mamba2 layer
layers[i].Operator = &Mamba2{Layer: i}
} else if nFF[i] == 0 {
// Attention-only layer (n_head_kv > 0, n_ff == 0)
layers[i].Operator = &Attention{}
} else {
// FFN layer (n_ff > 0)
if isMoE {
layers[i].MLP = &MoESparse{}
} else {
layers[i].MLP = &Dense{}
}
}
}
// Get attention head configuration
numHeads := int(c.Uint("attention.head_count"))
if numHeads == 0 {
for i := range numLayers {
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
numHeads = int(headCount[i])
break
}
}
}
numKVHeads := int(c.Uint("attention.head_count_kv"))
if numKVHeads == 0 {
for i := range numLayers {
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
numKVHeads = int(headCountKV[i])
break
}
}
if numKVHeads == 0 {
numKVHeads = numHeads
}
}
headDim := int(c.Uint("attention.head_dim"))
if headDim == 0 {
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
headDim = keyLength
} else if numHeads > 0 {
headDim = int(c.Uint("embedding_length")) / numHeads
}
}
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
}
if numHeads <= 0 {
// Attention layers derive per-layer head counts from projection weights.
// Keep a non-zero default to avoid invalid option math.
numHeads = 1
}
numExperts := int(c.Uint("expert_count"))
numExpertsUsed := int(c.Uint("expert_used_count"))
if numExperts > 0 {
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_rms_epsilon"),
ssmDConv: int(c.Uint("ssm.conv_kernel")),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNHead: int(c.Uint("ssm.time_step_rank")),
ssmNGroup: int(c.Uint("ssm.group_count")),
isRecurrent: isRecurrent,
nFF: nFF,
attentionScale: float64(c.Float("attention.scale")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
expertWeightsNorm: c.Bool("expert_weights_norm", false),
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
}
// Calculate cache dimensions
convDim := max(0, opts.ssmDConv-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
ssmHeadDim := 0
if opts.ssmNHead > 0 {
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
}
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
return &m, nil
}
func init() {
model.Register("nemotron_h", New)
model.Register("nemotron_h_moe", New)
}
// Ensure Model implements model.Model
var _ model.Model = (*Model)(nil)
// Dense implements standard feedforward with ReLU-squared activation
type Dense struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
// up -> ReLU-squared -> down
up := d.Up.Forward(ctx, x)
up = up.RELU(ctx)
up = up.Mul(ctx, up) // Square
return d.Down.Forward(ctx, up)
}
// MoESparse implements MoE with shared experts for Nemotron-H-MoE
type MoESparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
LatentIn *nn.Linear `gguf:"ffn_latent_in"`
LatentOut *nn.Linear `gguf:"ffn_latent_out"`
// Shared experts
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim := hiddenStates.Dim(0)
seqLen := hiddenStates.Dim(1)
batchSize := hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
// Router logits with sigmoid gating
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
// Weights come from unbiased sigmoid probabilities.
probs := routerLogits.Sigmoid(ctx)
// Selection uses optional bias.
selectionProbs := probs
if m.Bias != nil {
selectionProbs = selectionProbs.Add(ctx, m.Bias)
}
// Select top-k experts
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
// Normalize routing weights
if opts.expertWeightsNorm {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
weightsSum := routingWeights.SumRows(ctx)
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
routingWeights = routingWeights.Div(ctx, weightsSum)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
// Scale routing weights
if opts.expertWeightsScale != 1.0 {
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
}
routedInput := hiddenStates2D
if m.LatentIn != nil {
routedInput = m.LatentIn.Forward(ctx, routedInput)
}
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
// Expert computation with ReLU-squared activation
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
upOut = upOut.RELU(ctx)
upOut = upOut.Mul(ctx, upOut) // Square
experts := m.Down.Forward(ctx, upOut, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
if m.LatentOut != nil {
moeOut = m.LatentOut.Forward(ctx, moeOut)
}
// Add shared experts if present
if m.SharedUp != nil {
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
sharedUp = sharedUp.RELU(ctx)
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}

View File

@@ -32,6 +32,8 @@ type LFM2Parser struct {
hasThinkingSupport bool
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
toolNames map[string]struct{}
hasTools bool
}
func (p *LFM2Parser) HasToolSupport() bool {
@@ -63,6 +65,13 @@ func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.T
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.toolNames = make(map[string]struct{}, len(tools))
p.hasTools = len(tools) > 0
for _, tool := range tools {
if tool.Function.Name != "" {
p.toolNames[tool.Function.Name] = struct{}{}
}
}
p.setInitialState(lastMessage, thinkValue)
return tools
}
@@ -105,9 +114,33 @@ func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string,
}
}
// Fallback for models that emit bare tool calls without <|tool_call_*|> wrappers.
if done && len(toolCalls) == 0 && p.hasTools {
candidate := strings.TrimSpace(contentSb.String())
if fallbackCalls, parseErr := p.parseToolCallsContent(candidate); parseErr == nil && p.toolCallsAllowed(fallbackCalls) {
contentSb.Reset()
toolCalls = append(toolCalls, fallbackCalls...)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) toolCallsAllowed(calls []api.ToolCall) bool {
if len(calls) == 0 {
return false
}
if len(p.toolNames) == 0 {
return true
}
for _, call := range calls {
if _, ok := p.toolNames[call.Function.Name]; !ok {
return false
}
}
return true
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
@@ -269,36 +302,12 @@ func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
return events, false
}
// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
// parseToolCallsContent parses one or more Python-style tool calls.
// Example: [func1(arg='v'), func2(x=1)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return nil, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return []api.ToolCall{{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}}, nil
}
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
// Parse Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCalls(content)
}
@@ -417,21 +426,16 @@ func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error)
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
// Skip separators and whitespace.
for i < len(argsStr) && (argsStr[i] == ',' || unicode.IsSpace(rune(argsStr[i]))) {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
@@ -439,60 +443,238 @@ func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
key := strings.TrimSpace(argsStr[keyStart:i])
if key == "" {
return errors.New("invalid argument: empty key")
}
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
for i < len(argsStr) && unicode.IsSpace(rune(argsStr[i])) {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
if i >= len(argsStr) {
return errors.New("invalid argument: missing value")
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
value, next, err := parsePythonArgValue(argsStr, i)
if err != nil {
return err
}
args.Set(key, value)
i = next
// Optional trailing comma before next key/value.
if i < len(argsStr) && argsStr[i] == ',' {
i++
}
}
return nil
}
func parsePythonArgValue(s string, i int) (any, int, error) {
if i >= len(s) {
return nil, i, errors.New("invalid argument: missing value")
}
// Quoted string literal.
if s[i] == '\'' || s[i] == '"' {
quote := s[i]
i++
start := i
for i < len(s) {
if s[i] == '\\' && i+1 < len(s) {
i += 2
continue
}
if s[i] == quote {
value := s[start:i]
i++
return value, i, nil
}
i++
}
return nil, i, errors.New("invalid argument: unterminated string")
}
// Unquoted literal. Consume until top-level comma.
start := i
depthParen, depthSquare, depthCurly := 0, 0, 0
inString := false
var quote byte
escaped := false
for i < len(s) {
ch := s[i]
if inString {
if escaped {
escaped = false
} else if ch == '\\' {
escaped = true
} else if ch == quote {
inString = false
}
i++
continue
}
switch ch {
case '\'', '"':
inString = true
quote = ch
case '(':
depthParen++
case ')':
if depthParen > 0 {
depthParen--
}
case '[':
depthSquare++
case ']':
if depthSquare > 0 {
depthSquare--
}
case '{':
depthCurly++
case '}':
if depthCurly > 0 {
depthCurly--
}
case ',':
if depthParen == 0 && depthSquare == 0 && depthCurly == 0 {
token := strings.TrimSpace(s[start:i])
value, err := parsePythonLiteral(token)
return value, i, err
}
}
i++
}
token := strings.TrimSpace(s[start:i])
value, err := parsePythonLiteral(token)
return value, i, err
}
func parsePythonLiteral(token string) (any, error) {
switch token {
case "":
return "", nil
case "true", "True":
return true, nil
case "false", "False":
return false, nil
case "null", "None":
return nil, nil
}
if v, err := strconv.ParseInt(token, 10, 64); err == nil {
return v, nil
}
if v, err := strconv.ParseFloat(token, 64); err == nil {
return v, nil
}
if strings.HasPrefix(token, "[") || strings.HasPrefix(token, "{") {
var parsed any
if err := json.Unmarshal([]byte(token), &parsed); err == nil {
return parsed, nil
}
if converted, err := pythonLiteralToJSON(token); err == nil {
if err := json.Unmarshal([]byte(converted), &parsed); err == nil {
return parsed, nil
}
}
}
return token, nil
}
func pythonLiteralToJSON(s string) (string, error) {
var out strings.Builder
out.Grow(len(s) + len(s)/8)
inString := false
var quote byte
escaped := false
for i := 0; i < len(s); i++ {
ch := s[i]
if inString {
if escaped {
out.WriteByte(ch)
escaped = false
continue
}
if ch == '\\' {
out.WriteByte(ch)
escaped = true
continue
}
if ch == quote {
out.WriteByte('"')
inString = false
continue
}
if quote == '\'' && ch == '"' {
out.WriteString(`\"`)
continue
}
out.WriteByte(ch)
continue
}
if ch == '\'' || ch == '"' {
inString = true
quote = ch
escaped = false
out.WriteByte('"')
continue
}
// Replace Python identifiers with JSON equivalents when outside strings.
if isIdentStart(ch) {
j := i + 1
for j < len(s) && isIdentPart(s[j]) {
j++
}
ident := s[i:j]
switch ident {
case "True":
out.WriteString("true")
case "False":
out.WriteString("false")
case "None":
out.WriteString("null")
default:
out.WriteString(ident)
}
i = j - 1
continue
}
out.WriteByte(ch)
}
if inString {
return "", errors.New("unterminated string")
}
return out.String(), nil
}
func isIdentStart(b byte) bool {
return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || b == '_'
}
func isIdentPart(b byte) bool {
return isIdentStart(b) || (b >= '0' && b <= '9')
}

View File

@@ -39,7 +39,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "tool_call_simple",
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "I'll check the weather.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
@@ -55,7 +55,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
input: "Getting weather for both cities.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|tool_call_start|>[get_weather(location=\"London\")]<|tool_call_end|>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
@@ -79,7 +79,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "complex_tool_arguments",
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
input: "Processing data.<|tool_call_start|>[process_data(items=['item1','item2'], config={'enabled': True, 'threshold': 0.95})]<|tool_call_end|>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
@@ -96,7 +96,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "thinking_with_tool_call",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
@@ -144,16 +144,16 @@ func TestLFM2Parser(t *testing.T) {
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
name: "tool_call_with_text_args",
input: "Searching for information.<|tool_call_start|>[search(query='beijing weather', language='zh')]<|tool_call_end|>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "北京天气",
"language": "中文",
"query": "beijing weather",
"language": "zh",
}),
},
},
@@ -169,7 +169,7 @@ func TestLFM2Parser(t *testing.T) {
},
{
name: "empty_tool_call_args",
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
input: "Pinging server.<|tool_call_start|>[ping()]<|tool_call_end|>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
@@ -353,7 +353,7 @@ func TestLFM2Parser_Streaming(t *testing.T) {
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "[get_weather(", "location=\"Paris\")]", "<|tool_call_end|>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
@@ -381,16 +381,16 @@ func TestLFM2Parser_Streaming(t *testing.T) {
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
name: "streaming_tool_call_with_split_python",
chunks: []string{"Processing.", "<|tool_call_start|>", "[calc(", "x=42, ", "y=24)]", "<|tool_call_end|>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
"x": float64(42),
"y": float64(24),
"x": int64(42),
"y": int64(24),
}),
},
},
@@ -516,8 +516,8 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
expectError bool
}{
{
name: "valid_tool_call",
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
name: "python_style_single_call",
content: `get_weather(location="Paris")`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
@@ -528,21 +528,33 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "complex_arguments",
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
name: "python_style_with_brackets",
content: `[get_weather(location="Paris")]`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Name: "get_weather",
Arguments: testArgs(map[string]any{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
"location": "Paris",
}),
},
},
},
{
name: "empty_arguments",
content: `{"name":"ping","arguments":{}}`,
name: "python_style_complex_arguments",
content: `process(items=['a', 'b'], config={'enabled': True})`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
"items": []any{"a", "b"},
"config": map[string]any{"enabled": true},
}),
},
},
},
{
name: "python_style_empty_arguments",
content: `ping()`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
@@ -551,44 +563,13 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
},
},
{
name: "unicode_in_tool_name",
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
"城市": "北京",
}),
},
},
},
{
name: "numeric_arguments",
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
{
name: "invalid_json",
content: `{invalid json}`,
name: "missing_parenthesis",
content: `get_weather location="Paris")`,
expectError: true,
},
{
name: "missing_name",
content: `{"arguments":{"arg":"value"}}`,
expectError: true,
},
{
name: "empty_name",
content: `{"name":"","arguments":{"arg":"value"}}`,
name: "invalid_argument_format",
content: `bash(command)`,
expectError: true,
},
}
@@ -645,6 +626,24 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
},
},
},
{
name: "python_style_complex_literals",
content: `[AskUserQuestion(question="What's up?", headers=['Hello!', 'How can I help you?'], options=['Debugging help', 'Code writing assistance'], multiSelect=False, metadata={'priority': 1, 'active': True})]`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "AskUserQuestion",
Arguments: testArgs(map[string]any{
"question": "What's up?",
"headers": []any{"Hello!", "How can I help you?"},
"options": []any{"Debugging help", "Code writing assistance"},
"multiSelect": false,
"metadata": map[string]any{"priority": float64(1), "active": true},
}),
},
},
},
},
{
name: "single_python_style_call",
content: `bash(command='ls -la')`,
@@ -1086,3 +1085,106 @@ func TestLFM2Parser_EdgeCases(t *testing.T) {
})
}
}
func TestLFM2Parser_BareToolCallFallback(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add(`[get_weather(location="Paris")]`, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name get_weather, got %q", calls[0].Function.Name)
}
}
func TestLFM2Parser_BareUnknownToolCallDoesNotParse(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})
input := `[unknown_tool(location="Paris")]`
content, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if content != input {
t.Fatalf("expected content to be preserved, got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestLFM2Parser_ImagePlaceholdersPreserved(t *testing.T) {
tests := []struct {
name string
input string
}{
{
name: "indexed_img_placeholder",
input: "[img-0]describe this image",
},
{
name: "template_image_placeholder",
input: "<image>describe this image",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "bash",
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if content != tt.input {
t.Fatalf("expected content %q, got %q", tt.input, content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
})
}
}

View File

@@ -57,6 +57,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"lfm2"},
{"lfm2-thinking"},
{"harmony"},
}

View File

@@ -1,7 +1,9 @@
package renderers
import (
"bytes"
"encoding/json"
"sort"
"strings"
"github.com/ollama/ollama/api"
@@ -9,18 +11,219 @@ import (
type LFM2Renderer struct {
IsThinking bool
useImgTags bool
}
const lfm2BOSToken = "<|startoftext|>"
const (
lfm2ToolListStartTag = "<|tool_list_start|>"
lfm2ToolListEndTag = "<|tool_list_end|>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
lfm2ToolResponseStartTag = "<|tool_response_start|>"
lfm2ToolResponseEndTag = "<|tool_response_end|>"
)
func lfm2RenderSystemContent(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
var sb strings.Builder
for _, item := range v {
obj, ok := item.(map[string]any)
if !ok {
continue
}
if itemType, _ := obj["type"].(string); itemType == "text" {
if text, ok := obj["text"].(string); ok {
sb.WriteString(text)
}
}
}
return sb.String()
default:
return ""
}
}
func lfm2JSON(v any) string {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
if err := enc.Encode(v); err != nil {
fallback, _ := json.Marshal(v)
return string(fallback)
}
encoded := bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})
// HF `tojson` defaults to `json.dumps(..., separators=None)`, which inserts
// a space after commas and colons.
var out strings.Builder
out.Grow(len(encoded) + len(encoded)/8)
inString := false
escaped := false
for i, b := range encoded {
out.WriteByte(b)
if inString {
if escaped {
escaped = false
continue
}
if b == '\\' {
escaped = true
continue
}
if b == '"' {
inString = false
}
continue
}
if b == '"' {
inString = true
continue
}
if (b == ':' || b == ',') && i+1 < len(encoded) {
next := encoded[i+1]
if next != ' ' && next != '\n' && next != '\r' && next != '\t' {
out.WriteByte(' ')
}
}
}
return out.String()
}
func lfm2ImagePlaceholder(useImgTags bool) string {
if useImgTags {
return "[img]"
}
return "<image>"
}
func lfm2RenderContent(content any, useImgTags bool) string {
switch v := content.(type) {
case string:
return v
case []any:
var sb strings.Builder
for _, item := range v {
obj, ok := item.(map[string]any)
if !ok {
sb.WriteString(lfm2JSON(item))
continue
}
itemType, _ := obj["type"].(string)
switch itemType {
case "image":
sb.WriteString(lfm2ImagePlaceholder(useImgTags))
case "text":
if text, ok := obj["text"].(string); ok {
sb.WriteString(text)
} else {
sb.WriteString(lfm2JSON(item))
}
default:
sb.WriteString(lfm2JSON(item))
}
}
return sb.String()
default:
return lfm2JSON(content)
}
}
func lfm2ToolSchema(tool api.Tool) any {
if tool.Function.Name == "" {
return tool
}
// LFM2 tool prompts in HF are function-schema objects (name/description/parameters)
// rather than OpenAI's {type,function} wrapper.
return tool.Function
}
func lfm2ToolCallArgument(v any) string {
return lfm2JSON(v)
}
func lfm2RenderToolCalls(calls []api.ToolCall) string {
var sb strings.Builder
sb.WriteString(lfm2ToolCallStartTag)
sb.WriteString("[")
for i, tc := range calls {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(tc.Function.Name)
sb.WriteString("(")
keys := make([]string, 0, tc.Function.Arguments.Len())
for key := range tc.Function.Arguments.All() {
keys = append(keys, key)
}
sort.Strings(keys)
for j, key := range keys {
if j > 0 {
sb.WriteString(",")
}
value, _ := tc.Function.Arguments.Get(key)
sb.WriteString(key)
sb.WriteString("=")
sb.WriteString(lfm2ToolCallArgument(value))
}
sb.WriteString(")")
}
sb.WriteString("]")
sb.WriteString(lfm2ToolCallEndTag)
return sb.String()
}
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
content := lfm2RenderContent(message.Content, r.useImgTags)
if len(message.Images) == 0 {
return content
}
// chatPrompt may already have inserted [img] / [img-n] placeholders.
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
return content
}
var sb strings.Builder
placeholder := lfm2ImagePlaceholder(r.useImgTags)
for range message.Images {
sb.WriteString(placeholder)
}
sb.WriteString(content)
return sb.String()
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Match the source chat_template.jinja.
sb.WriteString(lfm2BOSToken)
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
firstSystemContent = lfm2RenderSystemContent(messages[0].Content)
startIdx = 1
}
@@ -29,18 +232,17 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
firstSystemContent += "List of tools: "
firstSystemContent += lfm2ToolListStartTag
firstSystemContent += "["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
firstSystemContent += lfm2JSON(lfm2ToolSchema(tool))
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
firstSystemContent += lfm2ToolListEndTag
}
// Output first system block if it has content
@@ -50,6 +252,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
sb.WriteString("<|im_end|>\n")
}
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
@@ -59,85 +263,43 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
sb.WriteString("<|im_start|>")
sb.WriteString(message.Role)
sb.WriteString("\n")
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
} else {
sb.WriteString(content)
content := r.renderMessageContent(message)
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
content = strings.TrimSpace(content[idx+len("</think>"):])
}
}
if message.Role == "assistant" && len(message.ToolCalls) > 0 && !strings.Contains(content, lfm2ToolCallStartTag) {
content = lfm2RenderToolCalls(message.ToolCalls) + content
}
if message.Role == "tool" && !strings.Contains(content, lfm2ToolResponseStartTag) {
content = lfm2ToolResponseStartTag + content + lfm2ToolResponseEndTag
}
sb.WriteString(content)
if !prefill {
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
needsGenerationPrompt := true
if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" {
needsGenerationPrompt = false
}
if needsGenerationPrompt {
// RenderWithRenderer uses add_generation_prompt=true for chat rendering,
// unless we're prefilling a trailing assistant message.
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil

View File

@@ -8,73 +8,113 @@ import (
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
func TestLFM2Renderer_ChatTemplateParity(t *testing.T) {
tests := []struct {
name string
renderer *LFM2Renderer
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
name: "user_only",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
name: "system_and_user",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
name: "tools_without_system",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "user", Content: "Use tools"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>system\nList of tools: <|tool_list_start|>[{\"name\": \"get_weather\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nUse tools<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
name: "first_system_combined_with_tools",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "system", Content: "Follow instructions."},
{Role: "user", Content: "Do work"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_a",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "tool_b",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>system\nFollow instructions.\nList of tools: <|tool_list_start|>[{\"name\": \"tool_a\", \"parameters\": {\"type\": \"object\", \"properties\": null}}, {\"name\": \"tool_b\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
"<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant_tool_calls_and_tool_responses_are_rendered",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "Call a tool"},
{
Role: "assistant",
Content: "",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|im_end|>\n<|im_start|>tool\n<|tool_response_start|>22C<|tool_response_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
name: "thinking_strips_non_last_assistant_when_disabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -82,11 +122,11 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
name: "thinking_preserves_past_assistant_when_enabled",
renderer: &LFM2Renderer{IsThinking: true},
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -94,334 +134,137 @@ func TestLFM2Renderer(t *testing.T) {
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
},
{
name: "assistant with tool calls",
name: "arbitrary_roles_are_rendered_verbatim",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "developer", Content: "Do X"},
{Role: "user", Content: "Hi"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
expected: "<|startoftext|><|im_start|>developer\nDo X<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with content and tool calls",
name: "empty_messages_still_add_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: nil,
thinkValue: &api.ThinkValue{Value: false},
expected: "<|startoftext|><|im_start|>assistant\n",
},
{
name: "assistant_prefill_no_generation_prompt",
renderer: &LFM2Renderer{IsThinking: false},
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
expected: "<|startoftext|><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
rendered, err := tt.renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Renderer_Images(t *testing.T) {
tests := []struct {
name string
renderer *LFM2Renderer
message api.Message
expected string
}{
{
name: "single_image_default_placeholder",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple_images_default_placeholder",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "Describe these images.",
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
},
expected: "<|startoftext|><|im_start|>user\n<image><image>Describe these images.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "single_image_img_tag_placeholder",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_indexed_img_placeholder_not_duplicated",
renderer: &LFM2Renderer{useImgTags: true},
message: api.Message{
Role: "user",
Content: "[img-0]Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "existing_template_image_placeholder_not_duplicated",
renderer: &LFM2Renderer{},
message: api.Message{
Role: "user",
Content: "<image>Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.renderer.Render([]api.Message{tt.message}, nil, nil)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, got); diff != "" {
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Renderer_JSONFormatting(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "echo",
Description: "<html>",
Parameters: api.ToolFunctionParameters{
Type: "object",
},
},
}
got := lfm2JSON(tool)
want := "{\"type\": \"function\", \"function\": {\"name\": \"echo\", \"description\": \"<html>\", \"parameters\": {\"type\": \"object\", \"properties\": null}}}"
if diff := cmp.Diff(want, got); diff != "" {
t.Fatalf("lfm2JSON mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -85,9 +85,9 @@ func rendererForName(name string) Renderer {
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
default:
return nil
}

View File

@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

View File

@@ -18,7 +18,9 @@
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
CHECK(handle->ctx != NULL);
if (handle->ctx == NULL) {
return 1;
}
return 0;
}

View File

@@ -55,6 +55,30 @@ func tryLoadFromDir(dir string) bool {
return false
}
// tryLoadByName attempts to load the library using just its name,
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
if runtime.GOOS == "linux" {
libraryName = "libmlxc.so"
}
cPath := C.CString(libraryName)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
return false
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
C.mlx_dynamic_unload(&handle)
return false
}
return true
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -73,6 +97,11 @@ func init() {
}
}
// Try loading via rpath/standard library search
if tryLoadByName() {
return
}
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {

View File

@@ -8,10 +8,10 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.

View File

@@ -7,7 +7,6 @@ import (
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@@ -126,13 +125,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
return flushValidUTF8Prefix(b)
}

View File

@@ -12,12 +12,12 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
type Request struct {

View File

@@ -0,0 +1,47 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}

View File

@@ -0,0 +1,46 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -9,12 +9,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

108
x/tokenizer/tokenizer.go Normal file
View File

@@ -0,0 +1,108 @@
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import "regexp"
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE and SentencePiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
sortedSpecialTokens []string // Special tokens sorted by length, longest first
typ TokenizerType // Algorithm type
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}

View File

@@ -0,0 +1,251 @@
//go:build mlx
package tokenizer
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var (
benchmarkSinkIDs []int32
benchmarkSinkStr string
benchmarkSinkTok *Tokenizer
)
const benchmarkWordPieceJSON = `{
"model": {
"type": "WordPiece",
"vocab": {
"[UNK]": 0,
"hello": 1,
"##world": 2,
"##ly": 3,
"##hello": 4
}
},
"added_tokens": []
}`
const benchmarkSentencePieceJSON = `{
"model": {
"type": "BPE",
"vocab": {
"\u2581": 0,
"h": 1,
"e": 2,
"l": 3,
"o": 4,
"w": 5,
"r": 6,
"d": 7,
"<0x0A>": 8
},
"merges": []
},
"decoder": {
"type": "Sequence",
"decoders": [
{
"type": "Replace",
"pattern": {
"String": "\u2581"
}
}
]
},
"added_tokens": []
}`
func benchmarkMiniLlamaPath(tb testing.TB) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve benchmark file path")
}
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
tb.Helper()
data := benchmarkLoadMiniLlamaBytes(tb)
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
}
return tok
}
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
tb.Helper()
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
if err != nil {
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
}
return data
}
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
tb.Helper()
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
}
return tok
}
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "short", text: "Hello, world!"},
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(input.text, false)
}
})
}
}
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
}
for _, input := range inputs {
ids := tok.Encode(input.text, false)
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
})
}
}
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
data := benchmarkLoadMiniLlamaBytes(b)
config := &TokenizerConfig{
TokenizerConfigJSON: []byte(`{
"bos_token": {"content": "<|begin_of_text|>"},
"eos_token": {"content": "<|end_of_text|>"},
"add_bos_token": true
}`),
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
}
b.Run("without_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytes(data)
if err != nil {
b.Fatalf("LoadFromBytes failed: %v", err)
}
benchmarkSinkTok = tok
}
})
b.Run("with_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytesWithConfig(data, config)
if err != nil {
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
}
benchmarkSinkTok = tok
}
})
}
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package tokenizer
import "container/heap"
type bpeMergeNode struct {
prev int
next int
token string
}
type bpePair struct {
left int
right int
rank int
value string
}
type bpePairHeap []*bpePair
func (h bpePairHeap) Len() int { return len(h) }
func (h bpePairHeap) Less(i, j int) bool {
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
}
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bpePairHeap) Push(x any) {
*h = append(*h, x.(*bpePair))
}
func (h *bpePairHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
runes := []rune(encoded)
if len(runes) == 0 {
return ids
}
nodes := make([]bpeMergeNode, len(runes))
for i := range runes {
nodes[i] = bpeMergeNode{
prev: i - 1,
next: i + 1,
token: string(runes[i]),
}
}
pairwise := func(left, right int) *bpePair {
if left < 0 || right >= len(nodes) {
return nil
}
if nodes[left].token == "" || nodes[right].token == "" {
return nil
}
leftToken, rightToken := nodes[left].token, nodes[right].token
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
if !ok {
return nil
}
value := leftToken + rightToken
if _, ok := t.vocab.Reverse[value]; !ok {
return nil
}
return &bpePair{
left: left,
right: right,
rank: rank,
value: value,
}
}
pairs := bpePairHeap{}
heap.Init(&pairs)
for i := 0; i < len(runes)-1; i++ {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(&pairs, pair)
}
}
for pairs.Len() > 0 {
pair := heap.Pop(&pairs).(*bpePair)
left, right := nodes[pair.left], nodes[pair.right]
if left.token == "" || right.token == "" {
continue
}
if left.next != pair.right || right.prev != pair.left {
continue
}
if left.token+right.token != pair.value {
continue
}
nodes[pair.left].token = pair.value
nodes[pair.right].token = ""
nodes[pair.left].next = right.next
if right.next < len(nodes) {
nodes[right.next].prev = pair.left
}
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
heap.Push(&pairs, pair)
}
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
heap.Push(&pairs, pair)
}
}
for _, node := range nodes {
if node.token == "" {
continue
}
if id, ok := t.vocab.Reverse[node.token]; ok {
ids = append(ids, id)
continue
}
ids = t.appendByteFallback(ids, node.token)
}
return ids
}
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
if t.typ == TokenizerBPE {
for _, r := range token {
if b, ok := decodeByteLevelRune(r); ok {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
return ids
}
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
for _, b := range []byte(token) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
return ids
}
func decodeByteLevelRune(r rune) (byte, bool) {
switch {
case r >= 0x00 && r <= 0xFF:
return byte(r), true
case r == 0x0100:
return 0x00, true
case r == 0x0143:
return 0x00ad, true
case r > 0x0100 && r <= 0x0120:
return byte(r - 0x0100), true
case r > 0x0120 && r <= 0x0142:
return byte(r - 0x00a2), true
default:
return 0, false
}
}

View File

@@ -0,0 +1,137 @@
//go:build mlx
package tokenizer
import (
"runtime"
"strings"
"testing"
)
func equalIDs(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestEncodeRoundtripMiniLlama(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
inputs := []string{
"",
"hello",
"hello world",
" hello world ",
"don't we'll they're",
"1234567890",
"こんにちは世界",
"Hello 世界",
"func main() {}",
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
}
for _, input := range inputs {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
}
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
// Simulate construction outside loader path where cache is not set.
tok.sortedSpecialTokens = nil
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
prev := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(prev)
runtime.GOMAXPROCS(1)
seq := tok.Encode(input, false)
if prev < 2 {
runtime.GOMAXPROCS(2)
} else {
runtime.GOMAXPROCS(prev)
}
par := tok.Encode(input, false)
if !equalIDs(seq, par) {
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
}
}

View File

@@ -0,0 +1,56 @@
//go:build mlx
package tokenizer
import (
"strconv"
"strings"
)
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// Mirror GGML tokenizer behavior for NULL byte.
// 0x00 is omitted during decode.
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,289 @@
//go:build mlx
package tokenizer
import (
"runtime"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
const (
encodeParallelMinInputBytes = 4 * 1024
encodeParallelMinChunksPerWorker = 8
)
type tokenMatch struct {
start int
end int
}
type encodeChunk struct {
text string
isSpecial bool
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
tokens := t.sortedSpecialTokens
if len(tokens) == 0 {
// Fallback for tokenizers constructed outside the loaders.
tokens = make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
}
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
m := part[curr.start:curr.end]
nextText := part[next.start:next.end]
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
return
}
firstRune, _ := utf8.DecodeRuneInString(nextText)
if !unicode.IsLetter(firstRune) {
return
}
lastSpaceStart := curr.end
for j := curr.end; j > curr.start; {
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > curr.start {
curr.end = lastSpaceStart
next.start = lastSpaceStart
} else {
next.start = curr.start
curr.end = curr.start
}
}
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
if _, ok := t.specialTokens[part]; ok {
fn(encodeChunk{text: part, isSpecial: true})
return
}
if t.pretokenizer == nil {
fn(encodeChunk{text: part, isSpecial: false})
return
}
re := t.pretokenizer
offset := 0
loc := re.FindStringIndex(part[offset:])
if loc == nil {
return
}
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
for {
loc = re.FindStringIndex(part[offset:])
if loc == nil {
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
return
}
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
adjustWhitespaceBoundary(part, &curr, &next)
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
curr = next
}
}
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
return append(ids, id)
}
return ids
}
return t.encodeChunkInto(c.text, ids)
}
// Encode tokenizes text to token IDs.
// Parallel encoding is used only for very large inputs with enough chunks per worker.
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Fast path: encode sequentially without materializing chunk slices.
if len(s) < encodeParallelMinInputBytes {
var ids []int32
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
ids = t.appendEncodedChunk(ids, c)
})
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// For large inputs collect chunks to enable parallel processing.
var allChunks []encodeChunk
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
allChunks = append(allChunks, c)
})
}
// Encode chunks. Use the parallel path only when the chunk count is
// large enough to amortize goroutine/synchronization overhead.
useParallel := true
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
useParallel = false
}
var ids []int32
if !useParallel {
for _, c := range allChunks {
ids = t.appendEncodedChunk(ids, c)
}
} else {
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []encodeChunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
r = t.appendEncodedChunk(r, c)
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}

View File

@@ -0,0 +1,207 @@
//go:build mlx
package tokenizer
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func llama32GGMLFixturePath(tb testing.TB, file string) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve test file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
tb.Helper()
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
if err != nil {
tb.Fatalf("failed to open encoder.json: %v", err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
tb.Fatalf("failed to decode encoder.json: %v", err)
}
type addedToken struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
}
var addedTokens []addedToken
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
id := int32(len(vocab))
vocab[token] = id
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
}
}
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
if err != nil {
tb.Fatalf("failed to open vocab.bpe: %v", err)
}
defer mf.Close()
var merges []string
scanner := bufio.NewScanner(mf)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
line = strings.TrimSpace(line)
if line != "" {
merges = append(merges, line)
}
}
if err := scanner.Err(); err != nil {
tb.Fatalf("failed to read vocab.bpe: %v", err)
}
payload := struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
AddedTokens []addedToken `json:"added_tokens"`
}{}
payload.Model.Type = "BPE"
payload.Model.Vocab = vocab
payload.Model.Merges = merges
payload.PreTokenizer.Type = "Sequence"
payload.PreTokenizer.Pretokenizers = []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}{
{
Type: "Split",
Pattern: struct {
Regex string `json:"Regex"`
}{
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
},
},
}
payload.AddedTokens = addedTokens
data, err := json.Marshal(payload)
if err != nil {
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
}
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
}
return tok
}
func TestGGMLLlamaKnownEncodings(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[string][]int32{
"hello world": {15339, 1917},
"hello <|end_of_text|>": {15339, 220, 128001},
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for input, want := range cases {
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[int][]int32{
1: {15},
2: {410},
3: {931},
4: {931, 15},
5: {931, 410},
6: {931, 931},
7: {931, 931, 15},
8: {931, 931, 410},
9: {931, 931, 931},
10: {931, 931, 931, 15},
11: {931, 931, 931, 410},
12: {931, 931, 931, 931},
13: {931, 931, 931, 931, 15},
14: {931, 931, 931, 931, 410},
15: {931, 931, 931, 931, 931},
16: {931, 931, 931, 931, 931, 15},
17: {931, 931, 931, 931, 931, 410},
}
for n, want := range cases {
input := strings.Repeat("0", n)
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, input := range cases {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
ids := tok.Encode(string(rune(0x00)), false)
got := tok.Decode(ids)
if got != "" {
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
}
}

View File

@@ -0,0 +1,458 @@
//go:build mlx
package tokenizer
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data)
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data)
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Covers SentencePiece and BPE models
if raw.Model.Type != "BPE" {
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
var mergesStrings []string
if raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
if len(pair) != 2 {
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
}
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
cacheSortedSpecialTokens(t)
return t, nil
}
func cacheSortedSpecialTokens(t *Tokenizer) {
if len(t.specialTokens) == 0 {
t.sortedSpecialTokens = nil
return
}
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
t.sortedSpecialTokens = tokens
}
type specialTokenConfigData struct {
tokenizerConfigJSON []byte
generationConfigJSON []byte
specialTokensMapJSON []byte
configJSON []byte
}
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.generationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.tokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.specialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
applySpecialTokenConfig(t, specialTokenConfigData{
tokenizerConfigJSON: config.TokenizerConfigJSON,
generationConfigJSON: config.GenerationConfigJSON,
specialTokensMapJSON: config.SpecialTokensMapJSON,
configJSON: config.ConfigJSON,
})
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}

View File

@@ -0,0 +1,26 @@
//go:build mlx
package tokenizer
import (
"strings"
"testing"
)
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
data := []byte(`{
"model": {
"type": "WordPiece",
"vocab": {"[UNK]": 0, "hello": 1}
},
"added_tokens": []
}`)
_, err := LoadFromBytes(data)
if err == nil {
t.Fatal("expected WordPiece load to fail")
}
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
t.Fatalf("unexpected error: %v", err)
}
}