From 87288ced4fbbcde15f1bab6126cb3bd03ee7dcfe Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 28 Apr 2026 11:50:12 -0700 Subject: [PATCH] New models (#15861) * mlx: add laguna model support * convert: support fp8 safetensors import Decode HF F8_E4M3 safetensors with block scale companions into GGUF-supported tensor types, and record which output tensors came from FP8 source weights. Use that source-precision metadata during create quantization: default FP8-sourced GGUFs to Q8_0, keep non-FP8 tensors at their original precision for Q8_0, and promote non-FP8 quantizable tensors to Q8_0 for Q4_K requests. * ggml: add laguna model support * server: preserve generate logprobs with builtin parsers Generate requests were dropping logprob-only chunks whenever a builtin parser buffered visible content. Chat already handled this case, but generate only forwarded chunks with visible response, thinking, or tool-call output. Keep generate chunks that carry logprobs even when the builtin parser has not flushed visible content yet, and add a regression test that exercises the behavior with a generic thinking parser. 
* review comments - perf improvements * ggml: implement nemotron 3 nano omni * add poolside integration * update poolside doc * adapt to new cache setup * fix test * fix test --------- Co-authored-by: Eva Ho --- cmd/launch/integrations_test.go | 50 +- cmd/launch/launch.go | 1 + cmd/launch/poolside.go | 51 + cmd/launch/poolside_test.go | 88 ++ cmd/launch/registry.go | 21 +- cmd/launch/runner_exec_only_test.go | 12 + convert/convert.go | 8 + convert/convert_laguna.go | 604 ++++++++ convert/convert_laguna_test.go | 450 ++++++ convert/convert_nemotron_h.go | 409 ++++++ convert/convert_nemotron_h_test.go | 310 +++++ convert/reader.go | 6 +- convert/reader_safetensors.go | 429 +++++- convert/reader_test.go | 229 ++++ convert/tensor.go | 52 + convert/tensor_test.go | 56 +- convert/tokenizer.go | 2 + docs/docs.json | 3 +- docs/integrations/index.mdx | 1 + docs/integrations/poolside.mdx | 54 + fs/ggml/ggml.go | 5 +- model/models/laguna/model.go | 444 ++++++ model/models/laguna/model_test.go | 237 ++++ model/models/models.go | 1 + model/models/nemotronh/imageproc.go | 355 +++++ model/models/nemotronh/model.go | 25 +- model/models/nemotronh/model_audio.go | 511 +++++++ model/models/nemotronh/model_omni.go | 239 ++++ model/models/nemotronh/model_omni_test.go | 606 ++++++++ model/models/nemotronh/model_vision.go | 348 +++++ model/models/nemotronh/process_audio.go | 328 +++++ model/parsers/laguna.go | 498 +++++++ model/parsers/laguna_test.go | 484 +++++++ model/parsers/nemotron3nano.go | 70 +- model/parsers/nemotron3nano_test.go | 50 +- model/parsers/parsers.go | 2 + model/renderers/laguna.go | 111 ++ model/renderers/laguna_test.go | 339 +++++ model/renderers/nemotron3nano.go | 297 +++- .../renderers/nemotron3nano_reference_test.go | 614 +++++++++ model/renderers/nemotron3nano_test.go | 551 +------- model/renderers/renderer.go | 2 + .../nemotron3nano_chat_template.jinja2 | 222 +++ server/create.go | 19 +- server/laguna_quantization_test.go | 90 ++ server/quantization.go | 118 +- 
server/quantization_test.go | 89 ++ server/routes.go | 6 +- server/routes_create_test.go | 158 +++ server/routes_generate_test.go | 113 ++ server/sched.go | 2 +- x/create/client/create.go | 32 +- x/create/client/create_test.go | 66 +- x/create/create.go | 19 +- x/create/laguna.go | 59 + x/create/laguna_test.go | 200 +++ x/mlxrunner/imports.go | 1 + x/mlxrunner/mlx/act.go | 13 + x/models/laguna/laguna.go | 1216 +++++++++++++++++ x/models/laguna/laguna_test.go | 509 +++++++ x/tokenizer/tokenizer_load.go | 4 +- x/tokenizer/tokenizer_load_test.go | 28 + 62 files changed, 11284 insertions(+), 633 deletions(-) create mode 100644 cmd/launch/poolside.go create mode 100644 cmd/launch/poolside_test.go create mode 100644 convert/convert_laguna.go create mode 100644 convert/convert_laguna_test.go create mode 100644 docs/integrations/poolside.mdx create mode 100644 model/models/laguna/model.go create mode 100644 model/models/laguna/model_test.go create mode 100644 model/models/nemotronh/imageproc.go create mode 100644 model/models/nemotronh/model_audio.go create mode 100644 model/models/nemotronh/model_omni.go create mode 100644 model/models/nemotronh/model_omni_test.go create mode 100644 model/models/nemotronh/model_vision.go create mode 100644 model/models/nemotronh/process_audio.go create mode 100644 model/parsers/laguna.go create mode 100644 model/parsers/laguna_test.go create mode 100644 model/renderers/laguna.go create mode 100644 model/renderers/laguna_test.go create mode 100644 model/renderers/nemotron3nano_reference_test.go create mode 100644 model/renderers/testdata/nemotron3nano_chat_template.jinja2 create mode 100644 server/laguna_quantization_test.go create mode 100644 x/create/laguna.go create mode 100644 x/create/laguna_test.go create mode 100644 x/models/laguna/laguna.go create mode 100644 x/models/laguna/laguna_test.go diff --git a/cmd/launch/integrations_test.go b/cmd/launch/integrations_test.go index ff2c2e3aa..c59bf8113 100644 --- 
a/cmd/launch/integrations_test.go +++ b/cmd/launch/integrations_test.go @@ -75,8 +75,7 @@ func TestIntegrationLookup(t *testing.T) { } func TestIntegrationRegistry(t *testing.T) { - expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"} - + expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes", "pool"} for _, name := range expectedIntegrations { t.Run(name, func(t *testing.T) { r, ok := integrations[name] @@ -1494,6 +1493,11 @@ func TestIntegration_InstallHint(t *testing.T) { input: "openclaw", wantURL: "https://docs.openclaw.ai", }, + { + name: "pool has hint", + input: "pool", + wantURL: "https://github.com/poolsideai/pool", + }, { name: "unknown has no hint", input: "unknown", @@ -1549,7 +1553,19 @@ func TestListIntegrationInfos(t *testing.T) { for _, info := range infos { got = append(got, info.Name) } - if diff := compareStrings(got, integrationOrder); diff != "" { + + want := append([]string(nil), integrationOrder...) 
+ if poolsideGOOS == "windows" { + filtered := make([]string, 0, len(want)) + for _, name := range want { + if name != "pool" { + filtered = append(filtered, name) + } + } + want = filtered + } + + if diff := compareStrings(got, want); diff != "" { t.Fatalf("launcher integration order mismatch: %s", diff) } }) @@ -1567,6 +1583,9 @@ func TestListIntegrationInfos(t *testing.T) { t.Run("includes known integrations", func(t *testing.T) { known := map[string]bool{"claude": false, "codex": false, "opencode": false} + if poolsideGOOS != "windows" { + known["pool"] = false + } for _, info := range infos { if _, ok := known[info.Name]; ok { known[info.Name] = true @@ -1601,6 +1620,17 @@ func TestListIntegrationInfos(t *testing.T) { } }) } +func TestListIntegrationInfos_HidesPoolsideOnWindows(t *testing.T) { + prev := poolsideGOOS + poolsideGOOS = "windows" + t.Cleanup(func() { poolsideGOOS = prev }) + + for _, info := range ListIntegrationInfos() { + if info.Name == "pool" { + t.Fatal("expected pool to be hidden on Windows") + } + } +} func TestBuildModelList_Descriptions(t *testing.T) { t.Run("installed recommended has base description", func(t *testing.T) { @@ -1707,6 +1737,20 @@ func TestIntegration_AutoInstallable(t *testing.T) { } } +func TestEnsureIntegrationInstalled_PoolsideUnsupportedOnWindows(t *testing.T) { + prev := poolsideGOOS + poolsideGOOS = "windows" + t.Cleanup(func() { poolsideGOOS = prev }) + + err := EnsureIntegrationInstalled("pool", &Poolside{}) + if err == nil { + t.Fatal("expected Windows unsupported error") + } + if !strings.Contains(err.Error(), "not currently supported on Windows") { + t.Fatalf("expected Windows warning, got %v", err) + } +} + func TestIntegrationModels(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) diff --git a/cmd/launch/launch.go b/cmd/launch/launch.go index fcc3d7329..3cd90f942 100644 --- a/cmd/launch/launch.go +++ b/cmd/launch/launch.go @@ -213,6 +213,7 @@ Supported integrations: opencode OpenCode openclaw 
OpenClaw (aliases: clawdbot, moltbot) pi Pi + pool Poolside vscode    VS Code (aliases: code) Examples: diff --git a/cmd/launch/poolside.go b/cmd/launch/poolside.go new file mode 100644 index 000000000..d9a9bb17a --- /dev/null +++ b/cmd/launch/poolside.go @@ -0,0 +1,51 @@ +package launch + +import ( + "fmt" + "os" + "os/exec" + "runtime" + + "github.com/ollama/ollama/envconfig" +) + +// Poolside implements Runner for Poolside's CLI. +type Poolside struct{} + +var poolsideGOOS = runtime.GOOS + +func (p *Poolside) String() string { return "Poolside" } + +func poolsideUnsupportedError() error { + return fmt.Errorf("Warning: Poolside is not currently supported on Windows") +} + +func (p *Poolside) args(model string, extra []string) []string { + var args []string + if model != "" { + args = append(args, "-m", model) + } + args = append(args, extra...) + return args +} + +func (p *Poolside) Run(model string, args []string) error { + if poolsideGOOS == "windows" { + return poolsideUnsupportedError() + } + + bin, err := exec.LookPath("pool") + if err != nil { + return fmt.Errorf("pool is not installed") + } + + cmd := exec.Command(bin, p.args(model, args)...) 
+ cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), + "POOLSIDE_STANDALONE_BASE_URL="+envconfig.Host().String()+"/v1", + "POOLSIDE_API_KEY=ollama", + ) + return cmd.Run() +} diff --git a/cmd/launch/poolside_test.go b/cmd/launch/poolside_test.go new file mode 100644 index 000000000..9608aefe6 --- /dev/null +++ b/cmd/launch/poolside_test.go @@ -0,0 +1,88 @@ +package launch + +import ( + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "testing" +) + +func TestPoolsideArgs(t *testing.T) { + p := &Poolside{} + + tests := []struct { + name string + model string + extra []string + want []string + }{ + {name: "with model", model: "qwen3.5", want: []string{"-m", "qwen3.5"}}, + {name: "without model", extra: []string{"session"}, want: []string{"session"}}, + {name: "with model and extra args", model: "llama3.2", extra: []string{"--foo", "bar"}, want: []string{"-m", "llama3.2", "--foo", "bar"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.args(tt.model, tt.extra) + if !slices.Equal(got, tt.want) { + t.Fatalf("args(%q, %v) = %v, want %v", tt.model, tt.extra, got, tt.want) + } + }) + } +} + +func TestPoolsideRunSetsOllamaEnv(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("uses POSIX shell fake binary") + } + + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "pool.log") + poolPath := filepath.Join(tmpDir, "pool") + script := "#!/bin/sh\n" + + "printf 'base=%s\\nkey=%s\\nargs=%s\\n' \"$POOLSIDE_STANDALONE_BASE_URL\" \"$POOLSIDE_API_KEY\" \"$*\" > \"" + logPath + "\"\n" + if err := os.WriteFile(poolPath, []byte(script), 0o755); err != nil { + t.Fatalf("failed to write fake pool binary: %v", err) + } + + t.Setenv("PATH", tmpDir) + t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434") + + p := &Poolside{} + if err := p.Run("qwen3.5", []string{"session"}); err != nil { + t.Fatalf("Run returned error: %v", err) + } + + data, err := os.ReadFile(logPath) + if err != 
nil { + t.Fatalf("failed to read pool log: %v", err) + } + + got := string(data) + if !strings.Contains(got, "base=http://127.0.0.1:11434/v1") { + t.Fatalf("expected Poolside base URL override in log, got:\n%s", got) + } + if !strings.Contains(got, "key=ollama") { + t.Fatalf("expected Poolside API key override in log, got:\n%s", got) + } + if !strings.Contains(got, "args=-m qwen3.5 session") { + t.Fatalf("expected model and extra args in log, got:\n%s", got) + } +} + +func TestPoolsideRunWindowsUnsupported(t *testing.T) { + prev := poolsideGOOS + poolsideGOOS = "windows" + t.Cleanup(func() { poolsideGOOS = prev }) + + p := &Poolside{} + err := p.Run("kimi-k2.6:cloud", nil) + if err == nil { + t.Fatal("expected Windows unsupported error") + } + if !strings.Contains(err.Error(), "not currently supported on Windows") { + t.Fatalf("expected Windows warning, got %v", err) + } +} diff --git a/cmd/launch/registry.go b/cmd/launch/registry.go index f390142ca..10d54c423 100644 --- a/cmd/launch/registry.go +++ b/cmd/launch/registry.go @@ -33,7 +33,7 @@ type IntegrationInfo struct { Description string } -var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"} +var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi", "pool"} var integrationSpecs = []*IntegrationSpec{ { @@ -166,6 +166,18 @@ var integrationSpecs = []*IntegrationSpec{ Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"}, }, }, + { + Name: "pool", + Runner: &Poolside{}, + Description: "Poolside's software agent for enterprise development", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("pool") + return err == nil + }, + URL: "https://github.com/poolsideai/pool", + }, + }, { Name: "hermes", Runner: &Hermes{}, @@ -285,6 +297,9 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec { if spec.Hidden { continue } + if 
spec.Name == "pool" && poolsideGOOS == "windows" { + continue + } visible = append(visible, *spec) } @@ -399,6 +414,10 @@ func EnsureIntegrationInstalled(name string, runner Runner) error { return fmt.Errorf("%s is not installed", runner) } + if integration.spec.Name == "pool" && poolsideGOOS == "windows" { + return poolsideUnsupportedError() + } + if integration.installed { return nil } diff --git a/cmd/launch/runner_exec_only_test.go b/cmd/launch/runner_exec_only_test.go index 137ed8192..66d3c67fe 100644 --- a/cmd/launch/runner_exec_only_test.go +++ b/cmd/launch/runner_exec_only_test.go @@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) { return filepath.Join(home, ".pi", "agent", "models.json") }, }, + { + name: "pool", + binary: "pool", + runner: &Poolside{}, + checkPath: func(home string) string { + return filepath.Join(home, ".poolside", "config") + }, + }, { name: "kimi", binary: "kimi", @@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if tt.name == "pool" && poolsideGOOS == "windows" { + t.Skip("Poolside is intentionally unsupported on Windows") + } + home := t.TempDir() setTestHome(t, home) diff --git a/convert/convert.go b/convert/convert.go index 876ac54c0..9092f00de 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -316,6 +316,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &deepseek2Model{} case "Glm4MoeLiteForCausalLM": conv = &glm4MoeLiteModel{} + case "LagunaForCausalLM": + conv = &lagunaModel{} case "GlmOcrForConditionalGeneration": conv = &glmOcrModel{} case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM": @@ -324,6 +326,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &lfm2VLTextModel{} case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration": conv = &qwen3NextModel{} + case "NemotronH_Nano_VL_V2", "NemotronH_Nano_Omni_Reasoning_V3": + 
conv = &nemotronHNanoVLModel{} case "NemotronHForCausalLM": conv = &nemotronHModel{} default: @@ -387,6 +391,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error { } func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error { + for k, v := range sourceTensorKV(ts) { + kv[k] = v + } + for i := range ts { ts[i].Shape = slices.Clone(ts[i].Shape) slices.Reverse(ts[i].Shape) diff --git a/convert/convert_laguna.go b/convert/convert_laguna.go new file mode 100644 index 000000000..f2f36e190 --- /dev/null +++ b/convert/convert_laguna.go @@ -0,0 +1,604 @@ +package convert + +import ( + "cmp" + "encoding/json" + "fmt" + iofs "io/fs" + "math" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type lagunaModel struct { + ModelParameters + + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + RMSNormEPS float32 `json:"rms_norm_eps"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + SlidingWindow uint32 `json:"sliding_window"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + Gating lagunaGatingMode `json:"gating"` + QKNormType string `json:"qk_norm_type"` + + LayerTypes []string `json:"layer_types"` + NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"` + + NumExperts uint32 `json:"num_experts"` + NumExpertsPerTok uint32 `json:"num_experts_per_tok"` + MoEIntermediateSize uint32 `json:"moe_intermediate_size"` + SharedExpertIntermediateSize uint32 `json:"shared_expert_intermediate_size"` + NormTopKProb bool `json:"norm_topk_prob"` + MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"` + MoERouterUseSigmoid bool `json:"moe_router_use_sigmoid"` + MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"` + DecoderSparseStep uint32 `json:"decoder_sparse_step"` + 
MLPOnlyLayers []uint32 `json:"mlp_only_layers"` + MLPLayerTypes []string `json:"mlp_layer_types"` + + RopeParameters lagunaRopeParameters `json:"rope_parameters"` + SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"` + + SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"` +} + +type lagunaGatingMode string + +type lagunaRopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + RopeType string `json:"rope_type"` + Type string `json:"type"` + Factor float32 `json:"factor"` + OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + BetaSlow float32 `json:"beta_slow"` + BetaFast float32 `json:"beta_fast"` + AttentionFactor float32 `json:"attention_factor"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` +} + +type lagunaRopeConfig struct { + flat lagunaRopeParameters + full lagunaRopeParameters + sliding lagunaRopeParameters + nested bool +} + +func (g *lagunaGatingMode) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err == nil { + *g = lagunaGatingMode(s) + return nil + } + + var enabled bool + if err := json.Unmarshal(b, &enabled); err == nil { + if enabled { + *g = "true" + } else { + *g = "false" + } + return nil + } + + if string(b) == "null" { + return nil + } + return fmt.Errorf("unsupported Laguna gating JSON value %s", string(b)) +} + +func (g lagunaGatingMode) perHead() bool { + return strings.EqualFold(string(g), "per-head") || strings.EqualFold(string(g), "true") +} + +func (r *lagunaRopeConfig) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + return nil + } + + var probe map[string]json.RawMessage + if err := json.Unmarshal(b, &probe); err != nil { + return err + } + + if len(probe) == 0 { + return nil + } + + if raw, ok := probe["full_attention"]; ok { + r.nested = true + if err := json.Unmarshal(raw, &r.full); err != nil { + return err + } + if raw = probe["sliding_attention"]; raw != nil { + if err := json.Unmarshal(raw, &r.sliding); 
err != nil { + return err + } + } + return nil + } + + if raw, ok := probe["global_attention"]; ok { + r.nested = true + if err := json.Unmarshal(raw, &r.full); err != nil { + return err + } + if raw = probe["sliding_attention"]; raw != nil { + if err := json.Unmarshal(raw, &r.sliding); err != nil { + return err + } + } + return nil + } + + return json.Unmarshal(b, &r.flat) +} + +func (r lagunaRopeConfig) fullParams() lagunaRopeParameters { + if r.nested { + return r.full + } + return r.flat +} + +func (r lagunaRopeConfig) slidingParams() (lagunaRopeParameters, bool) { + if !r.nested { + return lagunaRopeParameters{}, false + } + return r.sliding, true +} + +func (r lagunaRopeParameters) ropeType() string { + return cmp.Or(r.RopeType, r.Type) +} + +func (r lagunaRopeParameters) withDefaultPartialRotaryFactor(v float32) lagunaRopeParameters { + if r.PartialRotaryFactor == 0 { + r.PartialRotaryFactor = v + } + return r +} + +func (r lagunaRopeParameters) empty() bool { + return r == (lagunaRopeParameters{}) +} + +type rawLagunaModel struct { + ModelParameters + + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + RMSNormEPS float32 `json:"rms_norm_eps"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + SlidingWindow uint32 `json:"sliding_window"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + Gating lagunaGatingMode `json:"gating"` + QKNormType string `json:"qk_norm_type"` + + LayerTypes []string `json:"layer_types"` + NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"` + + NumExperts uint32 `json:"num_experts"` + NumExpertsPerTok uint32 `json:"num_experts_per_tok"` + MoEIntermediateSize uint32 `json:"moe_intermediate_size"` + SharedExpertIntermediateSize uint32 
`json:"shared_expert_intermediate_size"` + NormTopKProb *bool `json:"norm_topk_prob"` + MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"` + MoERouterUseSigmoid *bool `json:"moe_router_use_sigmoid"` + MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"` + DecoderSparseStep uint32 `json:"decoder_sparse_step"` + MLPOnlyLayers []uint32 `json:"mlp_only_layers"` + MLPLayerTypes []string `json:"mlp_layer_types"` + + RopeParameters lagunaRopeConfig `json:"rope_parameters"` + SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"` + + SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"` +} + +func (p *lagunaModel) UnmarshalJSON(b []byte) error { + var raw rawLagunaModel + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + + mlpOnlyLayers, err := lagunaDenseLayers(raw.MLPOnlyLayers, raw.MLPLayerTypes) + if err != nil { + return err + } + + fullRope := raw.RopeParameters.fullParams().withDefaultPartialRotaryFactor(cmp.Or(raw.PartialRotaryFactor, float32(1))) + swaRope := raw.SwaRopeParameters + if nestedSwa, ok := raw.RopeParameters.slidingParams(); ok && !nestedSwa.empty() { + swaRope = nestedSwa + } + swaRope = swaRope.withDefaultPartialRotaryFactor(cmp.Or(fullRope.PartialRotaryFactor, float32(1))) + + *p = lagunaModel{ + ModelParameters: raw.ModelParameters, + NumHiddenLayers: raw.NumHiddenLayers, + HiddenSize: raw.HiddenSize, + IntermediateSize: raw.IntermediateSize, + NumAttentionHeads: raw.NumAttentionHeads, + NumKeyValueHeads: raw.NumKeyValueHeads, + HeadDim: raw.HeadDim, + RMSNormEPS: raw.RMSNormEPS, + MaxPositionEmbeddings: raw.MaxPositionEmbeddings, + SlidingWindow: raw.SlidingWindow, + PartialRotaryFactor: cmp.Or(raw.PartialRotaryFactor, fullRope.PartialRotaryFactor), + Gating: raw.Gating, + QKNormType: cmp.Or(raw.QKNormType, "rmsnorm"), + LayerTypes: raw.LayerTypes, + NumAttentionHeadsPerLayer: raw.NumAttentionHeadsPerLayer, + NumExperts: raw.NumExperts, + NumExpertsPerTok: 
raw.NumExpertsPerTok, + MoEIntermediateSize: raw.MoEIntermediateSize, + SharedExpertIntermediateSize: raw.SharedExpertIntermediateSize, + NormTopKProb: defaultBool(raw.NormTopKProb, true), + MoeRoutedScalingFactor: raw.MoeRoutedScalingFactor, + MoERouterUseSigmoid: defaultBool(raw.MoERouterUseSigmoid, true), + MoEApplyRouterWeightOnInput: raw.MoEApplyRouterWeightOnInput, + DecoderSparseStep: raw.DecoderSparseStep, + MLPOnlyLayers: mlpOnlyLayers, + MLPLayerTypes: raw.MLPLayerTypes, + RopeParameters: fullRope, + SwaRopeParameters: swaRope, + SwaAttentionSinkEnabled: raw.SwaAttentionSinkEnabled, + } + return nil +} + +func defaultBool(v *bool, fallback bool) bool { + if v == nil { + return fallback + } + return *v +} + +const ( + lagunaGatingFuncSoftmax uint32 = 1 + lagunaGatingFuncSigmoid uint32 = 2 + + lagunaLayerTypeGlobal uint32 = 0 + lagunaLayerTypeSliding uint32 = 1 +) + +func (p *lagunaModel) KV(t *Tokenizer) KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "laguna" + // Laguna's chat template and built-in renderer both emit the leading + // special token explicitly. Auto-prepending BOS here would duplicate it. + kv["tokenizer.ggml.add_bos_token"] = false + kv["tokenizer.ggml.pre"] = "laguna" + // Laguna does not need tokenizer.chat_template at runtime: Ollama create + // sets the Laguna renderer/parser from the architecture, and the renderer + // owns prompt formatting. 
+ delete(kv, "tokenizer.chat_template") + + kv["laguna.block_count"] = p.NumHiddenLayers + kv["laguna.context_length"] = p.MaxPositionEmbeddings + kv["laguna.embedding_length"] = p.HiddenSize + kv["laguna.feed_forward_length"] = p.IntermediateSize + + if len(p.NumAttentionHeadsPerLayer) == int(p.NumHiddenLayers) { + kv["laguna.attention.head_count"] = p.NumAttentionHeadsPerLayer + } else { + kv["laguna.attention.head_count"] = p.NumAttentionHeads + } + kv["laguna.attention.head_count_kv"] = p.NumKeyValueHeads + kv["laguna.attention.key_length"] = p.HeadDim + kv["laguna.attention.value_length"] = p.HeadDim + kv["laguna.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + kv["laguna.attention.sliding_window"] = p.SlidingWindow + kv["laguna.attention.sink_enabled"] = p.SwaAttentionSinkEnabled + + if len(p.LayerTypes) > 0 { + encoded := make([]uint32, len(p.LayerTypes)) + slidingPattern := make([]bool, len(p.LayerTypes)) + for i, layerType := range p.LayerTypes { + if lagunaLayerIsSliding(layerType) { + encoded[i] = lagunaLayerTypeSliding + slidingPattern[i] = true + } else { + encoded[i] = lagunaLayerTypeGlobal + } + } + kv["laguna.attention.layer_types"] = encoded + kv["laguna.attention.sliding_window_pattern"] = slidingPattern + } + + if p.Gating.perHead() { + kv["laguna.attention.gating_type"] = uint32(1) + } else { + kv["laguna.attention.gating_type"] = uint32(0) + } + kv["laguna.attention.qk_norm"] = p.QKNormType == "rmsnorm" + + kv["laguna.expert_count"] = p.NumExperts + kv["laguna.expert_used_count"] = p.NumExpertsPerTok + kv["laguna.expert_feed_forward_length"] = p.MoEIntermediateSize + kv["laguna.expert_shared_feed_forward_length"] = p.SharedExpertIntermediateSize + kv["laguna.expert_shared_count"] = uint32(1) + kv["laguna.expert_weights_norm"] = p.NormTopKProb + kv["laguna.expert_weights_scale"] = p.MoeRoutedScalingFactor + kv["laguna.expert_gating_func"] = lagunaMoeGatingFunc(p.MoERouterUseSigmoid) + kv["laguna.decoder_sparse_step"] = 
cmp.Or(p.DecoderSparseStep, uint32(1)) + + if leading, ok := lagunaLeadingDensePrefix(p.MLPOnlyLayers); ok { + kv["laguna.leading_dense_block_count"] = leading + } + if len(p.MLPOnlyLayers) > 0 { + kv["laguna.dense_layers"] = p.MLPOnlyLayers + } + + ropeType := p.RopeParameters.ropeType() + kv["laguna.rope.freq_base"] = cmp.Or(p.RopeParameters.RopeTheta, float32(10000)) + kv["laguna.rope.scaling.type"] = ropeType + ropeFactor := cmp.Or(p.RopeParameters.Factor, float32(1)) + kv["laguna.rope.scaling.factor"] = ropeFactor + kv["laguna.rope.scaling.original_context_length"] = p.RopeParameters.OriginalMaxPositionEmbeddings + kv["laguna.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast + kv["laguna.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow + kv["laguna.rope.scaling.attn_factor"] = lagunaAttentionFactor(ropeType, ropeFactor, p.RopeParameters.AttentionFactor) + kv["laguna.rope.partial_rotary_factor"] = cmp.Or(p.PartialRotaryFactor, float32(1)) + + swaRopeType := p.SwaRopeParameters.ropeType() + kv["laguna.rope.swa.freq_base"] = cmp.Or(p.SwaRopeParameters.RopeTheta, float32(10000)) + kv["laguna.rope.swa.scaling.type"] = cmp.Or(swaRopeType, "linear") + kv["laguna.rope.swa.scaling.factor"] = cmp.Or(p.SwaRopeParameters.Factor, float32(1)) + kv["laguna.rope.swa.partial_rotary_factor"] = cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1)) + + headDim := p.HeadDim + if headDim == 0 && p.NumAttentionHeads > 0 { + headDim = p.HiddenSize / p.NumAttentionHeads + } + kv["laguna.rope.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.PartialRotaryFactor, float32(1))) + kv["laguna.rope.swa.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1))) + + return kv +} + +func (p *lagunaModel) parseMore(_ iofs.FS) error { + return p.validate() +} + +func (p *lagunaModel) validate() error { + if p.NumHiddenLayers == 0 { + return fmt.Errorf("laguna: num_hidden_layers must be set") + } + if p.HiddenSize == 0 { + return 
fmt.Errorf("laguna: hidden_size must be set") + } + if p.HeadDim == 0 { + return fmt.Errorf("laguna: head_dim must be set") + } + if p.NumKeyValueHeads == 0 { + return fmt.Errorf("laguna: num_key_value_heads must be set") + } + if p.SwaAttentionSinkEnabled { + return fmt.Errorf("laguna: unsupported swa_attention_sink_enabled=true") + } + if !p.Gating.perHead() { + return fmt.Errorf("laguna: unsupported attention gating %q: only gating=\"per-head\" is supported", p.Gating) + } + if p.QKNormType != "rmsnorm" { + return fmt.Errorf("laguna: unsupported qk_norm_type %q: only rmsnorm is supported", p.QKNormType) + } + if !p.MoERouterUseSigmoid { + return fmt.Errorf("laguna: unsupported moe_router_use_sigmoid=false") + } + if p.MoEApplyRouterWeightOnInput { + return fmt.Errorf("laguna: unsupported moe_apply_router_weight_on_input=true") + } + if p.DecoderSparseStep != 0 && p.DecoderSparseStep != 1 { + return fmt.Errorf("laguna: unsupported decoder_sparse_step=%d: only 1 is supported", p.DecoderSparseStep) + } + if len(p.MLPOnlyLayers) != 1 || p.MLPOnlyLayers[0] != 0 { + return fmt.Errorf("laguna: unsupported mlp_only_layers=%v: only [0] is supported", p.MLPOnlyLayers) + } + if p.NumExperts == 0 { + return fmt.Errorf("laguna: num_experts must be set") + } + if p.NumExpertsPerTok == 0 { + return fmt.Errorf("laguna: num_experts_per_tok must be set") + } + if p.MoEIntermediateSize == 0 { + return fmt.Errorf("laguna: moe_intermediate_size must be set") + } + if p.SharedExpertIntermediateSize == 0 { + return fmt.Errorf("laguna: shared_expert_intermediate_size must be set") + } + + if len(p.LayerTypes) > 0 && len(p.LayerTypes) != int(p.NumHiddenLayers) { + return fmt.Errorf("laguna: layer_types has %d entries, expected %d", len(p.LayerTypes), p.NumHiddenLayers) + } + for i, layerType := range p.LayerTypes { + if !lagunaLayerIsGlobal(layerType) && !lagunaLayerIsSliding(layerType) { + return fmt.Errorf("laguna: unsupported layer_types[%d]=%q", i, layerType) + } + } + if 
len(p.NumAttentionHeadsPerLayer) > 0 && len(p.NumAttentionHeadsPerLayer) != int(p.NumHiddenLayers) { + return fmt.Errorf("laguna: num_attention_heads_per_layer has %d entries, expected %d", len(p.NumAttentionHeadsPerLayer), p.NumHiddenLayers) + } + if len(p.NumAttentionHeadsPerLayer) == 0 && p.NumAttentionHeads == 0 { + return fmt.Errorf("laguna: num_attention_heads or num_attention_heads_per_layer must be set") + } + for i, heads := range p.NumAttentionHeadsPerLayer { + if heads == 0 { + return fmt.Errorf("laguna: num_attention_heads_per_layer[%d] must be non-zero", i) + } + } + + return nil +} + +func (p *lagunaModel) numHeadsForLayer(layer uint32) uint32 { + if len(p.NumAttentionHeadsPerLayer) > int(layer) && p.NumAttentionHeadsPerLayer[layer] > 0 { + return p.NumAttentionHeadsPerLayer[layer] + } + return p.NumAttentionHeads +} + +func (p *lagunaModel) layerUsesMoE(layer uint32) bool { + for _, denseLayer := range p.MLPOnlyLayers { + if denseLayer == layer { + return false + } + } + step := cmp.Or(p.DecoderSparseStep, uint32(1)) + return p.NumExperts > 0 && (layer+1)%step == 0 +} + +func (p *lagunaModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.norm", "output_norm", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "post_attention_layernorm", "ffn_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "self_attn.g_proj", "attn_g", + "self_attn.q_norm", "attn_q_norm", + "self_attn.k_norm", "attn_k_norm", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "mlp.down_proj", "ffn_down", + "mlp.gate.weight", "ffn_gate_inp.weight", + "mlp.experts.e_score_correction_bias", "exp_probs_b.bias", + "mlp.shared_expert.gate_proj", "ffn_gate_shexp", + "mlp.shared_expert.up_proj", "ffn_up_shexp", + "mlp.shared_expert.down_proj", "ffn_down_shexp", + "mlp.experts.*.gate_proj", "ffn_gate_exps", + 
// lagunaLeadingDensePrefix checks whether layers is exactly the sequence
// 0, 1, 2, ... and, if so, returns its length together with true. For any
// other contents it returns (0, false). An empty or nil slice yields
// (0, true).
func lagunaLeadingDensePrefix(layers []uint32) (uint32, bool) {
	var expected uint32
	for idx := 0; idx < len(layers); idx++ {
		if layers[idx] != expected {
			return 0, false
		}
		expected++
	}
	return expected, true
}
// lagunaRopeDim returns how many dimensions of an attention head are rotated
// by RoPE, given the head size and the configured partial rotary factor.
//
// The result is headDim*partialRotaryFactor truncated toward zero and then
// rounded down to an even number. Whenever that would be zero or exceed
// headDim, the full head dimension is used instead. A zero headDim yields 0.
//
// A factor that is not strictly between 0 and 1 (including negative, NaN and
// infinite values) selects the full head dimension without performing the
// float-to-integer conversion: Go leaves the result of converting an
// unrepresentable floating-point value to an integer implementation-dependent,
// so the conversion is only done when the product is known to fit in uint32.
func lagunaRopeDim(headDim uint32, partialRotaryFactor float32) uint32 {
	if headDim == 0 {
		return 0
	}

	dim := headDim
	// NaN fails both comparisons and therefore keeps dim == headDim.
	if partialRotaryFactor > 0 && partialRotaryFactor < 1 {
		dim = uint32(float32(headDim) * partialRotaryFactor)
	}
	if dim == 0 || dim > headDim {
		dim = headDim
	}

	// Clear the lowest bit to round down to an even count.
	dim &^= 1
	if dim == 0 {
		// Only reachable when headDim is 1; keep the head usable.
		return headDim
	}
	return dim
}
strings.NewReplacer(p.Replacements()...) + + tests := []struct { + name string + in string + want string + }{ + {"embed", "model.embed_tokens.weight", "token_embd.weight"}, + {"final_norm", "model.norm.weight", "output_norm.weight"}, + {"lm_head", "lm_head.weight", "output.weight"}, + {"block prefix", "model.layers.7.input_layernorm.weight", "blk.7.attn_norm.weight"}, + {"q", "model.layers.3.self_attn.q_proj.weight", "blk.3.attn_q.weight"}, + {"k", "model.layers.3.self_attn.k_proj.weight", "blk.3.attn_k.weight"}, + {"v", "model.layers.3.self_attn.v_proj.weight", "blk.3.attn_v.weight"}, + {"o", "model.layers.3.self_attn.o_proj.weight", "blk.3.attn_output.weight"}, + {"g", "model.layers.3.self_attn.g_proj.weight", "blk.3.attn_g.weight"}, + {"q_norm", "model.layers.3.self_attn.q_norm.weight", "blk.3.attn_q_norm.weight"}, + {"k_norm", "model.layers.3.self_attn.k_norm.weight", "blk.3.attn_k_norm.weight"}, + {"post_attn_norm", "model.layers.3.post_attention_layernorm.weight", "blk.3.ffn_norm.weight"}, + {"dense gate", "model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"}, + {"dense up", "model.layers.0.mlp.up_proj.weight", "blk.0.ffn_up.weight"}, + {"dense down", "model.layers.0.mlp.down_proj.weight", "blk.0.ffn_down.weight"}, + {"shexp gate", "model.layers.5.mlp.shared_expert.gate_proj.weight", "blk.5.ffn_gate_shexp.weight"}, + {"shexp up", "model.layers.5.mlp.shared_expert.up_proj.weight", "blk.5.ffn_up_shexp.weight"}, + {"shexp down", "model.layers.5.mlp.shared_expert.down_proj.weight", "blk.5.ffn_down_shexp.weight"}, + {"router", "model.layers.5.mlp.gate.weight", "blk.5.ffn_gate_inp.weight"}, + {"score bias", "model.layers.5.mlp.experts.e_score_correction_bias", "blk.5.exp_probs_b.bias"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := r.Replace(tc.in); got != tc.want { + t.Errorf("Replace(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestLagunaValidateRejectsUnsupportedVariants(t *testing.T) { + base 
:= validLagunaTestModel() + tests := []struct { + name string + edit func(*lagunaModel) + want string + }{ + { + name: "per-element gating", + edit: func(m *lagunaModel) { + m.Gating = "per-element" + }, + want: "unsupported attention gating", + }, + { + name: "attention sinks", + edit: func(m *lagunaModel) { + m.SwaAttentionSinkEnabled = true + }, + want: "swa_attention_sink_enabled=true", + }, + { + name: "qk norm disabled", + edit: func(m *lagunaModel) { + m.QKNormType = "none" + }, + want: "unsupported qk_norm_type", + }, + { + name: "softmax moe", + edit: func(m *lagunaModel) { + m.MoERouterUseSigmoid = false + }, + want: "moe_router_use_sigmoid=false", + }, + { + name: "router weight on input", + edit: func(m *lagunaModel) { + m.MoEApplyRouterWeightOnInput = true + }, + want: "moe_apply_router_weight_on_input=true", + }, + { + name: "unknown layer type", + edit: func(m *lagunaModel) { + m.LayerTypes[1] = "local_attention" + }, + want: "unsupported layer_types[1]", + }, + { + name: "nonstandard dense layout", + edit: func(m *lagunaModel) { + m.MLPOnlyLayers = []uint32{0, 3} + }, + want: "unsupported mlp_only_layers", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + m := base + m.LayerTypes = append([]string(nil), base.LayerTypes...) + m.NumAttentionHeadsPerLayer = append([]uint32(nil), base.NumAttentionHeadsPerLayer...) + m.MLPOnlyLayers = append([]uint32(nil), base.MLPOnlyLayers...) 
+ tc.edit(&m) + err := m.validate() + if err == nil || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("validate() error = %v, want substring %q", err, tc.want) + } + }) + } +} + +func TestLagunaGAConfigNormalizesBoolGatingAndNestedRope(t *testing.T) { + var m lagunaModel + if err := json.Unmarshal([]byte(`{ + "architectures": ["LagunaForCausalLM"], + "num_hidden_layers": 1, + "hidden_size": 8, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 4, + "gating": true, + "num_experts": 2, + "num_experts_per_tok": 1, + "moe_intermediate_size": 4, + "shared_expert_intermediate_size": 4, + "decoder_sparse_step": 1, + "mlp_layer_types": ["dense"], + "rope_parameters": { + "full_attention": { + "rope_theta": 500000, + "rope_type": "yarn", + "factor": 32, + "original_max_position_embeddings": 4096, + "beta_fast": 64, + "beta_slow": 1, + "attention_factor": 1, + "partial_rotary_factor": 0.5 + }, + "sliding_attention": { + "rope_theta": 10000, + "rope_type": "default", + "partial_rotary_factor": 1 + } + } + }`), &m); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + if err := m.validate(); err != nil { + t.Fatalf("validate() error = %v", err) + } + if m.Gating != "true" { + t.Fatalf("Gating = %q, want raw true marker", m.Gating) + } + if !m.Gating.perHead() { + t.Fatal("expected bool gating to normalize as per-head support") + } + if m.QKNormType != "rmsnorm" { + t.Fatalf("QKNormType = %q, want rmsnorm default", m.QKNormType) + } + if !m.MoERouterUseSigmoid { + t.Fatal("MoERouterUseSigmoid should default true") + } + if !m.NormTopKProb { + t.Fatal("NormTopKProb should default true") + } + if diff := cmp.Diff(m.MLPOnlyLayers, []uint32{0}); diff != "" { + t.Fatalf("MLPOnlyLayers mismatch (-got +want):\n%s", diff) + } + if m.RopeParameters.RopeTheta != 500000 || m.RopeParameters.PartialRotaryFactor != 0.5 { + t.Fatalf("full rope = %#v, want theta=500000 partial=0.5", m.RopeParameters) + } + if m.SwaRopeParameters.RopeTheta != 10000 || 
m.SwaRopeParameters.PartialRotaryFactor != 1 { + t.Fatalf("swa rope = %#v, want theta=10000 partial=1", m.SwaRopeParameters) + } +} + +func validLagunaTestModel() lagunaModel { + return lagunaModel{ + ModelParameters: ModelParameters{ + VocabSize: 32, + }, + NumHiddenLayers: 2, + HiddenSize: 8, + IntermediateSize: 16, + NumAttentionHeads: 2, + NumKeyValueHeads: 1, + HeadDim: 4, + RMSNormEPS: 1e-6, + MaxPositionEmbeddings: 4096, + SlidingWindow: 512, + Gating: "per-head", + QKNormType: "rmsnorm", + LayerTypes: []string{"global_attention", "sliding_attention"}, + NumAttentionHeadsPerLayer: []uint32{2, 2}, + NumExperts: 2, + NumExpertsPerTok: 1, + MoEIntermediateSize: 4, + SharedExpertIntermediateSize: 4, + NormTopKProb: true, + MoeRoutedScalingFactor: 2.5, + MoERouterUseSigmoid: true, + DecoderSparseStep: 1, + MLPOnlyLayers: []uint32{0}, + } +} + +func validLagunaTestTensors(m lagunaModel) []Tensor { + ts := []Tensor{ + newLagunaTestTensor("token_embd.weight", uint64(m.VocabSize), uint64(m.HiddenSize)), + newLagunaTestTensor("output_norm.weight", uint64(m.HiddenSize)), + } + + for layer := range m.NumHiddenLayers { + prefix := fmt.Sprintf("blk.%d", layer) + heads := uint64(m.numHeadsForLayer(layer)) + attnWidth := heads * uint64(m.HeadDim) + kvWidth := uint64(m.NumKeyValueHeads * m.HeadDim) + ts = append(ts, + newLagunaTestTensor(prefix+".attn_norm.weight", uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".ffn_norm.weight", uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".attn_q.weight", attnWidth, uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".attn_k.weight", kvWidth, uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".attn_v.weight", kvWidth, uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".attn_output.weight", uint64(m.HiddenSize), attnWidth), + newLagunaTestTensor(prefix+".attn_g.weight", heads, uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".attn_q_norm.weight", uint64(m.HeadDim)), + newLagunaTestTensor(prefix+".attn_k_norm.weight", 
uint64(m.HeadDim)), + ) + + if m.layerUsesMoE(layer) { + ts = append(ts, + newLagunaTestTensor(prefix+".ffn_gate_inp.weight", uint64(m.NumExperts), uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".exp_probs_b.bias", uint64(m.NumExperts)), + newLagunaTestTensor(prefix+".ffn_gate_shexp.weight", uint64(m.SharedExpertIntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".ffn_up_shexp.weight", uint64(m.SharedExpertIntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".ffn_down_shexp.weight", uint64(m.HiddenSize), uint64(m.SharedExpertIntermediateSize)), + ) + for expert := range m.NumExperts { + ts = append(ts, + newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.gate_proj.weight", prefix, expert), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.up_proj.weight", prefix, expert), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.down_proj.weight", prefix, expert), uint64(m.HiddenSize), uint64(m.MoEIntermediateSize)), + ) + } + } else { + ts = append(ts, + newLagunaTestTensor(prefix+".ffn_gate.weight", uint64(m.IntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".ffn_up.weight", uint64(m.IntermediateSize), uint64(m.HiddenSize)), + newLagunaTestTensor(prefix+".ffn_down.weight", uint64(m.HiddenSize), uint64(m.IntermediateSize)), + ) + } + } + + return ts +} + +func TestLagunaTensorsMergeRoutedExperts(t *testing.T) { + m := validLagunaTestModel() + out := m.Tensors(validLagunaTestTensors(m)) + + tensors := make(map[string]*ggml.Tensor, len(out)) + for _, t := range out { + tensors[t.Name] = t + } + + tests := map[string][]uint64{ + "blk.1.ffn_gate_exps.weight": {uint64(m.NumExperts), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)}, + "blk.1.ffn_up_exps.weight": {uint64(m.NumExperts), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)}, + "blk.1.ffn_down_exps.weight": {uint64(m.NumExperts), 
uint64(m.HiddenSize), uint64(m.MoEIntermediateSize)}, + } + for name, wantShape := range tests { + tensor, ok := tensors[name] + if !ok { + t.Fatalf("missing merged tensor %q", name) + } + if diff := cmp.Diff(wantShape, tensor.Shape); diff != "" { + t.Fatalf("%s shape mismatch (-want +got):\n%s", name, diff) + } + } + + for expert := range m.NumExperts { + name := fmt.Sprintf("blk.1.mlp.experts.%d.gate_proj.weight", expert) + if _, ok := tensors[name]; ok { + t.Fatalf("unexpected unmerged expert tensor %q", name) + } + } +} + +func TestLagunaKVShape(t *testing.T) { + m := lagunaModel{ + NumHiddenLayers: 4, + HiddenSize: 128, + IntermediateSize: 256, + NumAttentionHeads: 8, + NumKeyValueHeads: 4, + HeadDim: 16, + RMSNormEPS: 1e-6, + MaxPositionEmbeddings: 4096, + SlidingWindow: 512, + PartialRotaryFactor: 0.5, + Gating: "per-head", + QKNormType: "rmsnorm", + LayerTypes: []string{"full_attention", "sliding_attention", "sliding_attention", "sliding_attention"}, + NumAttentionHeadsPerLayer: []uint32{8, 16, 16, 16}, + NumExperts: 32, + NumExpertsPerTok: 4, + MoEIntermediateSize: 64, + SharedExpertIntermediateSize: 64, + NormTopKProb: true, + MoeRoutedScalingFactor: 2.5, + MoERouterUseSigmoid: true, + DecoderSparseStep: 1, + MLPOnlyLayers: []uint32{0}, + } + m.RopeParameters.RopeTheta = 500000 + m.RopeParameters.RopeType = "yarn" + m.RopeParameters.Factor = 32 + m.RopeParameters.OriginalMaxPositionEmbeddings = 4096 + m.RopeParameters.BetaFast = 64 + m.RopeParameters.BetaSlow = 1 + m.SwaRopeParameters.RopeTheta = 10000 + m.SwaRopeParameters.RopeType = "linear" + m.SwaRopeParameters.Factor = 1 + m.SwaRopeParameters.PartialRotaryFactor = 1 + + kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}, Template: "{% include 'chat_template.jinja' %}"}) + required := []string{ + "general.architecture", + "tokenizer.ggml.pre", + "laguna.block_count", + "laguna.context_length", + "laguna.embedding_length", + "laguna.feed_forward_length", + "laguna.attention.head_count", + 
"laguna.attention.head_count_kv", + "laguna.attention.key_length", + "laguna.attention.value_length", + "laguna.attention.layer_norm_rms_epsilon", + "laguna.attention.sliding_window", + "laguna.attention.layer_types", + "laguna.attention.sliding_window_pattern", + "laguna.attention.gating_type", + "laguna.attention.qk_norm", + "laguna.expert_count", + "laguna.expert_used_count", + "laguna.expert_feed_forward_length", + "laguna.expert_shared_feed_forward_length", + "laguna.expert_shared_count", + "laguna.expert_weights_norm", + "laguna.expert_weights_scale", + "laguna.expert_gating_func", + "laguna.leading_dense_block_count", + "laguna.dense_layers", + "laguna.rope.freq_base", + "laguna.rope.scaling.type", + "laguna.rope.scaling.factor", + "laguna.rope.partial_rotary_factor", + "laguna.rope.swa.freq_base", + "laguna.rope.swa.scaling.type", + "laguna.rope.dimension_count", + "laguna.rope.swa.dimension_count", + } + for _, k := range required { + if _, ok := kv[k]; !ok { + t.Errorf("missing required KV: %s", k) + } + } + + if got := kv["general.architecture"]; got != "laguna" { + t.Errorf("architecture = %v, want laguna", got) + } + if got := kv["tokenizer.ggml.add_bos_token"]; got != false { + t.Errorf("tokenizer.ggml.add_bos_token = %v, want false", got) + } + if _, ok := kv["tokenizer.chat_template"]; ok { + t.Fatal("tokenizer.chat_template should be omitted for Laguna") + } + if got := kv["laguna.expert_gating_func"]; got != lagunaGatingFuncSigmoid { + t.Errorf("expert_gating_func = %v, want sigmoid(%d)", got, lagunaGatingFuncSigmoid) + } + if got := kv["laguna.leading_dense_block_count"]; got != uint32(1) { + t.Errorf("leading_dense_block_count = %v, want 1", got) + } + if got := kv["laguna.rope.dimension_count"]; got != uint32(8) { + t.Errorf("rope.dimension_count = %v, want 8", got) + } + if got := kv["laguna.rope.swa.dimension_count"]; got != uint32(16) { + t.Errorf("rope.swa.dimension_count = %v, want 16", got) + } + if got, ok := 
kv["laguna.attention.layer_types"].([]uint32); !ok || len(got) != 4 || got[0] != 0 || got[1] != 1 || got[2] != 1 || got[3] != 1 { + t.Fatalf("layer_types = %#v, want [0 1 1 1]", kv["laguna.attention.layer_types"]) + } + if got, ok := kv["laguna.attention.sliding_window_pattern"].([]bool); !ok || len(got) != 4 || got[0] || !got[1] || !got[2] || !got[3] { + t.Fatalf("sliding_window_pattern = %#v, want [false true true true]", kv["laguna.attention.sliding_window_pattern"]) + } +} + +func TestLagunaKVYarnAttentionFactorFallback(t *testing.T) { + m := validLagunaTestModel() + m.RopeParameters.RopeType = "yarn" + m.RopeParameters.Factor = 32 + + kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}}) + got, ok := kv["laguna.rope.scaling.attn_factor"].(float32) + if !ok { + t.Fatalf("attn_factor type = %T, want float32", kv["laguna.rope.scaling.attn_factor"]) + } + + want := float32(0.1*math.Log(32) + 1) + if diff := math.Abs(float64(got - want)); diff > 1e-6 { + t.Fatalf("attn_factor = %v, want %v", got, want) + } +} diff --git a/convert/convert_nemotron_h.go b/convert/convert_nemotron_h.go index 59ea461f1..4eab2006f 100644 --- a/convert/convert_nemotron_h.go +++ b/convert/convert_nemotron_h.go @@ -3,6 +3,7 @@ package convert import ( "cmp" "encoding/json" + "errors" "fmt" "io/fs" "math" @@ -69,7 +70,415 @@ type nemotronHModel struct { ExpertGroupUsedCount uint32 `json:"topk_group"` } +type nemotronHNanoVLModel struct { + ModelParameters + MaxSequenceLength uint32 `json:"max_sequence_length"` + ForceImageSize uint32 `json:"force_image_size"` + DownsampleRatio float32 `json:"downsample_ratio"` + PatchSize uint32 `json:"patch_size"` + UseThumbnail *bool `json:"use_thumbnail"` + ImgContextTokenID uint32 `json:"img_context_token_id"` + ImgContextToken string `json:"img_context_token"` + ImgStartToken string `json:"img_start_token"` + ImgEndToken string `json:"img_end_token"` + VitHiddenSize uint32 `json:"vit_hidden_size"` + ProjectorHidden uint32 `json:"projector_hidden_size"` + 
SoundContextTokenID uint32 `json:"sound_context_token_id"` + SoundContextToken string `json:"sound_context_token"` + NormMean []float32 `json:"norm_mean"` + NormStd []float32 `json:"norm_std"` + VisionConfig radioConfig `json:"vision_config"` + SoundConfig soundConfig `json:"sound_config"` + LLMConfig nemotronHModel `json:"llm_config"` + Preprocessor struct { + ImageSize uint32 `json:"image_size"` + PatchSize uint32 `json:"patch_size"` + DownsampleRatio float32 `json:"downsample_ratio"` + MaxNumTiles uint32 `json:"max_num_tiles"` + UseThumbnail *bool `json:"use_thumbnail"` + NormMean []float32 `json:"norm_mean"` + NormStd []float32 `json:"norm_std"` + } +} + +type soundConfig struct { + ModelType string `json:"model_type"` + HiddenSize uint32 `json:"hidden_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + ConvKernelSize uint32 `json:"conv_kernel_size"` + SubsamplingConvChannels uint32 `json:"subsampling_conv_channels"` + SubsamplingConvKernelSize uint32 `json:"subsampling_conv_kernel_size"` + SubsamplingConvStride uint32 `json:"subsampling_conv_stride"` + SubsamplingFactor uint32 `json:"subsampling_factor"` + NumMelBins uint32 `json:"num_mel_bins"` + ProjectionHiddenSize uint32 `json:"projection_hidden_size"` + SamplingRate uint32 `json:"sampling_rate"` + ScaleInput bool `json:"scale_input"` +} + +type radioConfig struct { + Version string `json:"version"` + PatchSize uint32 `json:"patch_size"` + MaxResolution uint32 `json:"max_resolution"` + MinNumPatches uint32 `json:"min_num_patches"` + MaxNumPatches uint32 `json:"max_num_patches"` + SeparateVideoEmbedder bool `json:"separate_video_embedder"` + Args struct { + MinNumPatches uint32 `json:"min_num_patches"` + MaxNumPatches uint32 `json:"max_num_patches"` + } `json:"args"` +} + var _ ModelConverter = (*nemotronHModel)(nil) +var _ ModelConverter = (*nemotronHNanoVLModel)(nil) + +func (n 
*nemotronHNanoVLModel) parseMore(fsys fs.FS) error { + if n.MaxSequenceLength > 0 { + n.LLMConfig.MaxPositionEmbeddings = n.MaxSequenceLength + } + + if err := n.LLMConfig.parseMore(fsys); err != nil { + return err + } + + if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil { + if err := json.Unmarshal(bts, &n.Preprocessor); err != nil { + return fmt.Errorf("nemotron_h_omni: parse preprocessor_config.json: %w", err) + } + } else if !errors.Is(err, fs.ErrNotExist) { + return err + } + + if version := strings.TrimSpace(n.VisionConfig.Version); version != "" && version != "radio_v2.5-h" { + return fmt.Errorf("nemotron_h_omni: unsupported RADIO version %q", version) + } + if patchSize := n.visionPatchSize(); patchSize != 16 { + return fmt.Errorf("nemotron_h_omni: unsupported vision patch_size=%d", patchSize) + } + if scale := n.visionProjectorScaleFactor(); scale != 2 { + return fmt.Errorf("nemotron_h_omni: unsupported vision projector scale factor=%d", scale) + } + + if n.SoundConfig.NumHiddenLayers > 0 { + if modelType := strings.TrimSpace(n.SoundConfig.ModelType); modelType != "" && modelType != "parakeet" { + return fmt.Errorf("nemotron_h_omni: unsupported sound model_type %q", modelType) + } + if n.soundHiddenSize() == 0 { + return fmt.Errorf("nemotron_h_omni: sound hidden_size must be set") + } + if n.soundAttentionHeads() == 0 { + return fmt.Errorf("nemotron_h_omni: sound num_attention_heads must be set") + } + if n.soundSubsamplingFactor() != 8 { + return fmt.Errorf("nemotron_h_omni: unsupported sound subsampling_factor=%d", n.soundSubsamplingFactor()) + } + if n.soundMelBins() != 128 { + return fmt.Errorf("nemotron_h_omni: unsupported sound num_mel_bins=%d", n.soundMelBins()) + } + } + + return nil +} + +func (n *nemotronHNanoVLModel) KV(t *Tokenizer) KV { + kv := n.LLMConfig.KV(t) + kv["general.architecture"] = "nemotron_h_omni" + + kv["vision.block_count"] = n.visionBlockCount() + kv["vision.embedding_length"] = n.visionEmbeddingLength() 
+ kv["vision.feed_forward_length"] = n.visionFeedForwardLength() + kv["vision.attention.head_count"] = n.visionAttentionHeads() + kv["vision.attention.layer_norm_epsilon"] = float32(1e-6) + kv["vision.patch_size"] = n.visionPatchSize() + kv["vision.image_size"] = n.visionImageSize() + kv["vision.max_tiles"] = n.visionMaxTiles() + kv["vision.use_thumbnail"] = n.visionUseThumbnail() + if minPatches := n.visionMinNumPatches(); minPatches > 0 { + kv["vision.min_num_patches"] = minPatches + } + if maxPatches := n.visionMaxNumPatches(); maxPatches > 0 { + kv["vision.max_num_patches"] = maxPatches + } + kv["vision.num_channels"] = uint32(3) + kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(n.visionMean(), imageNetStandardMean)) + kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(n.visionStd(), imageNetStandardSTD)) + kv["vision.projector.scale_factor"] = n.visionProjectorScaleFactor() + + setTokenID := func(key string, explicit uint32, token string) { + if explicit > 0 { + kv[key] = explicit + return + } + if t == nil || t.Vocabulary == nil { + return + } + for i, v := range t.Vocabulary.Tokens { + if v == token { + kv[key] = uint32(i) + return + } + } + } + + setTokenID("vision.image_token_id", n.ImgContextTokenID, cmp.Or(n.ImgContextToken, "")) + setTokenID("vision.image_start_token_id", 0, cmp.Or(n.ImgStartToken, "")) + setTokenID("vision.image_end_token_id", 0, cmp.Or(n.ImgEndToken, "")) + + if n.SoundConfig.NumHiddenLayers > 0 { + kv["audio.block_count"] = n.SoundConfig.NumHiddenLayers + kv["audio.embedding_length"] = n.soundHiddenSize() + kv["audio.feed_forward_length"] = n.soundFeedForwardLength() + kv["audio.attention.head_count"] = n.soundAttentionHeads() + kv["audio.attention.layer_norm_epsilon"] = float32(1e-5) + kv["audio.conv_kernel_size"] = n.soundConvKernelSize() + kv["audio.num_mel_bins"] = n.soundMelBins() + kv["audio.sample_rate"] = n.soundSampleRate() + kv["audio.subsampling_factor"] = n.soundSubsamplingFactor() + 
kv["audio.subsampling_conv_channels"] = n.soundSubsamplingConvChannels() + kv["audio.subsampling_conv_kernel_size"] = n.soundSubsamplingConvKernelSize() + kv["audio.subsampling_conv_stride"] = n.soundSubsamplingConvStride() + kv["audio.projection_hidden_size"] = n.soundProjectionHiddenSize() + kv["audio.scale_input"] = n.SoundConfig.ScaleInput + setTokenID("audio.sound_token_id", n.SoundContextTokenID, cmp.Or(n.SoundContextToken, "")) + } + + return kv +} + +func (n *nemotronHNanoVLModel) Tensors(ts []Tensor) []*ggml.Tensor { + var textTensors []Tensor + var out []*ggml.Tensor + + for _, t := range ts { + switch { + case isNemotronHNanoVLOmittedTensor(t.Name()): + continue + case strings.Contains(t.Name(), ".attn_qkv"): + out = append(out, slices.Collect(splitDim(t, 0, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")}, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")}, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")}, + ))...) + case t.Name() == "v.position_embd": + shape := t.Shape() + if len(shape) == 3 && shape[0] == 1 { + shape = shape[1:] + } + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: shape, + WriterTo: t, + }) + case strings.HasPrefix(t.Name(), "a.") || strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm."): + name := t.Name() + shape := slices.Clone(t.Shape()) + if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, ".conv_dw.") && strings.HasSuffix(name, ".weight") && len(shape) == 3 { + t.SetRepacker(squeezeMiddleDim) + shape = []uint64{shape[0], shape[2]} + } + if strings.HasPrefix(name, "a.blk.") && (strings.Contains(name, ".conv_pw1.") || strings.Contains(name, ".conv_pw2.")) && strings.HasSuffix(name, ".weight") && len(shape) == 3 && shape[2] == 1 { + t.SetRepacker(squeezeLastDim) + shape = shape[:2] + } + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: shape, + WriterTo: t, + }) + default: + textTensors = append(textTensors, t) + } + 
} + + return append(n.LLMConfig.Tensors(textTensors), out...) +} + +func (n *nemotronHNanoVLModel) Replacements() []string { + return append([]string{ + "language_model.", "", + "vision_model.radio_model.model.patch_generator.embedder", "v.patch_embd", + "vision_model.radio_model.model.patch_generator.pos_embed", "v.position_embd", + "vision_model.radio_model.model.patch_generator.cls_token.token", "v.cls_embd", + "vision_model.radio_model.model.blocks", "v.blk", + "attn.qkv", "attn_qkv", + "attn.proj", "attn_out", + "mlp.fc1", "ffn_up", + "mlp.fc2", "ffn_down", + "norm1", "ln1", + "norm2", "ln2", + "mlp1.0", "mm.norm", + "mlp1.1", "mm.1", + "mlp1.3", "mm.2", + "sound_encoder.encoder.feature_extractor.featurizer.fb", "a.feature_extractor.fb", + "sound_encoder.encoder.feature_extractor.featurizer.window", "a.feature_extractor.window", + "sound_encoder.encoder.subsampling.layers.0", "a.subsampling.conv0", + "sound_encoder.encoder.subsampling.layers.2", "a.subsampling.dw1", + "sound_encoder.encoder.subsampling.layers.3", "a.subsampling.pw1", + "sound_encoder.encoder.subsampling.layers.5", "a.subsampling.dw2", + "sound_encoder.encoder.subsampling.layers.6", "a.subsampling.pw2", + "sound_encoder.encoder.subsampling.linear", "a.subsampling.linear", + "sound_encoder.encoder.layers", "a.blk", + "feed_forward1.linear1", "ffn1_up", + "feed_forward1.linear2", "ffn1_down", + "feed_forward2.linear1", "ffn2_up", + "feed_forward2.linear2", "ffn2_down", + "norm_feed_forward1", "ffn1_norm", + "norm_feed_forward2", "ffn2_norm", + "norm_self_att", "attn_norm", + "norm_conv", "conv_norm", + "norm_out", "out_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_out", + "self_attn.relative_k_proj", "attn_rel_k", + "self_attn.bias_u", "attn_bias_u", + "self_attn.bias_v", "attn_bias_v", + "conv.pointwise_conv1", "conv_pw1", + "conv.pointwise_conv2", "conv_pw2", + "conv.depthwise_conv", "conv_dw", + "conv.norm", 
// isNemotronHNanoVLOmittedTensor reports whether a tensor name identifies a
// tensor that is left out of the converted model: names ending in
// ".conv_bn.num_batches_tracked", and names starting with the RADIO input
// conditioner or patch-generator video embedder prefixes.
func isNemotronHNanoVLOmittedTensor(name string) bool {
	if strings.HasSuffix(name, ".conv_bn.num_batches_tracked") {
		return true
	}

	omittedPrefixes := [...]string{
		"vision_model.radio_model.input_conditioner.",
		"vision_model.radio_model.model.patch_generator.video_embedder",
	}
	for _, prefix := range omittedPrefixes {
		if strings.HasPrefix(name, prefix) {
			return true
		}
	}
	return false
}
// Per-channel image normalization constants (three values each). KV passes
// them to defaultFloat32Slice together with the configured norm_mean /
// norm_std when it emits vision.image_mean and vision.image_std.
//
// NOTE(review): these numbers look like the OpenAI CLIP normalization
// mean/std rather than ImageNet statistics, despite the names — confirm
// against the upstream preprocessor before relying on the naming.
var (
	imageNetStandardMean = []float32{0.48145466, 0.4578275, 0.40821073}
	imageNetStandardSTD  = []float32{0.26862954, 0.26130258, 0.27577711}
)
"ssm_state_size": 128, + "mamba_num_heads": 16, + "mamba_head_dim": 32, + "n_groups": 8, + "hybrid_override_pattern": "ME*M", + "n_routed_experts": 16, + "num_experts_per_tok": 4, + "moe_intermediate_size": 256 + } + }` + + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tempDir, "preprocessor_config.json"), []byte(`{ + "image_size": 512, + "patch_size": 16, + "downsample_ratio": 0.5, + "max_num_tiles": 12, + "use_thumbnail": true, + "norm_mean": [0.48145466, 0.4578275, 0.40821073], + "norm_std": [0.26862954, 0.26130258, 0.27577711] + }`), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + conv, tokenizer, err := LoadModelMetadata(os.DirFS(tempDir)) + if err != nil { + t.Fatal(err) + } + if _, ok := conv.(*nemotronHNanoVLModel); !ok { + t.Fatalf("unexpected converter type: %T", conv) + } + + kv := conv.KV(tokenizer) + if got, want := kv["general.architecture"], "nemotron_h_omni"; got != want { + t.Fatalf("unexpected architecture: got %v want %v", got, want) + } + if got, want := kv["context_length"], uint32(131072); got != want { + t.Fatalf("unexpected context length: got %v want %v", got, want) + } + if got, want := kv["vision.block_count"], uint32(32); got != want { + t.Fatalf("unexpected vision block count: got %v want %v", got, want) + } + if got, want := kv["vision.image_size"], uint32(512); got != want { + t.Fatalf("unexpected vision image size: got %v want %v", got, want) + } + if got, want := kv["vision.projector.scale_factor"], uint32(2); got != want { + t.Fatalf("unexpected projector scale factor: got %v want %v", got, want) + } + if got, want := kv["audio.block_count"], uint32(24); got != want { + t.Fatalf("unexpected audio block count: got %v want %v", got, want) + } + if got, want := kv["audio.sound_token_id"], uint32(27); got != 
want { + t.Fatalf("unexpected audio token id: got %v want %v", got, want) + } + if got, want := kv["audio.subsampling_factor"], uint32(8); got != want { + t.Fatalf("unexpected audio subsampling factor: got %v want %v", got, want) + } +} + +func TestNemotronHNanoOmniReasoningV3LoadModelMetadata(t *testing.T) { + tempDir := t.TempDir() + + config := `{ + "architectures": ["NemotronH_Nano_Omni_Reasoning_V3"], + "model_type": "NemotronH_Nano_Omni_Reasoning_V3", + "max_sequence_length": 131072, + "force_image_size": 512, + "downsample_ratio": 0.5, + "patch_size": 16, + "img_context_token_id": 18, + "img_context_token": "", + "img_start_token": "", + "img_end_token": "", + "sound_context_token_id": 27, + "sound_context_token": "", + "vit_hidden_size": 1280, + "projector_hidden_size": 4096, + "vision_config": { + "version": "radio_v2.5-h", + "patch_size": 16, + "min_num_patches": 1024, + "max_num_patches": 13312, + "args": { + "min_num_patches": 1024, + "max_num_patches": 13312 + } + }, + "sound_config": { + "model_type": "parakeet", + "hidden_size": 1024, + "num_attention_heads": 8, + "num_hidden_layers": 24, + "intermediate_size": 4096, + "conv_kernel_size": 9, + "subsampling_conv_channels": 256, + "subsampling_conv_kernel_size": 3, + "subsampling_conv_stride": 2, + "subsampling_factor": 8, + "num_mel_bins": 128, + "projection_hidden_size": 4096, + "sampling_rate": 16000 + }, + "llm_config": { + "architectures": ["NemotronHForCausalLM"], + "model_type": "nemotron_h", + "num_hidden_layers": 4, + "hidden_size": 512, + "max_position_embeddings": 262144, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 64, + "layer_norm_epsilon": 1e-5, + "conv_kernel": 4, + "ssm_state_size": 128, + "mamba_num_heads": 16, + "mamba_head_dim": 32, + "n_groups": 8, + "hybrid_override_pattern": "ME*M", + "n_routed_experts": 16, + "num_experts_per_tok": 4, + "moe_intermediate_size": 256 + } + }` + + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 
0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + conv, tokenizer, err := LoadModelMetadata(os.DirFS(tempDir)) + if err != nil { + t.Fatal(err) + } + if _, ok := conv.(*nemotronHNanoVLModel); !ok { + t.Fatalf("unexpected converter type: %T", conv) + } + + kv := conv.KV(tokenizer) + if got, want := kv["general.architecture"], "nemotron_h_omni"; got != want { + t.Fatalf("unexpected architecture: got %v want %v", got, want) + } + if got, want := kv["vision.block_count"], uint32(32); got != want { + t.Fatalf("unexpected vision block count: got %v want %v", got, want) + } + if got, want := kv["vision.min_num_patches"], uint32(1024); got != want { + t.Fatalf("unexpected vision min patches: got %v want %v", got, want) + } + if got, want := kv["vision.max_num_patches"], uint32(13312); got != want { + t.Fatalf("unexpected vision max patches: got %v want %v", got, want) + } + if got, want := kv["audio.block_count"], uint32(24); got != want { + t.Fatalf("unexpected audio block count: got %v want %v", got, want) + } + if got, want := kv["audio.sound_token_id"], uint32(27); got != want { + t.Fatalf("unexpected audio token id: got %v want %v", got, want) + } +} + +func TestNemotronHNanoVLTensorsRetainVisionAndAudio(t *testing.T) { + m := &nemotronHNanoVLModel{ + LLMConfig: nemotronHModel{NGroups: 8}, + } + + in := []Tensor{ + &fakeTensor{ + name: "blk.0.ssm_a", + shape: []uint64{4}, + data: []float32{0, 1, 2, 3}, + }, + &fakeTensor{name: "v.blk.0.attn_qkv.weight", shape: []uint64{3840, 1280}}, + &fakeTensor{name: "v.position_embd", shape: []uint64{1, 16384, 1280}}, + &fakeTensor{name: "v.cls_embd", shape: []uint64{10, 1280}}, + &fakeTensor{name: "mm.norm.weight", shape: []uint64{5120}}, + &fakeTensor{name: "a.feature_extractor.fb", shape: []uint64{1, 128, 257}}, + &fakeTensor{name: "a.subsampling.dw1.weight", shape: []uint64{256, 1, 3, 3}}, + &fakeTensor{name: 
"a.blk.0.conv_dw.weight", shape: []uint64{1024, 1, 9}}, + &fakeTensor{name: "a.blk.0.conv_pw1.weight", shape: []uint64{2048, 1024, 1}}, + &fakeTensor{name: "a.blk.0.conv_bn.num_batches_tracked", shape: []uint64{1}}, + &fakeTensor{name: "mm.a.1.weight", shape: []uint64{4096, 1024}}, + } + + out := m.Tensors(in) + got := map[string][]uint64{} + for _, tns := range out { + got[tns.Name] = tns.Shape + } + + for _, name := range []string{ + "blk.0.ssm_a", + "v.blk.0.attn_q.weight", + "v.blk.0.attn_k.weight", + "v.blk.0.attn_v.weight", + "v.position_embd", + "v.cls_embd", + "mm.norm.weight", + "a.feature_extractor.fb", + "a.subsampling.dw1.weight", + "a.blk.0.conv_dw.weight", + "a.blk.0.conv_pw1.weight", + "mm.a.1.weight", + } { + if _, ok := got[name]; !ok { + t.Fatalf("expected tensor %q in output", name) + } + } + + if gotShape, want := got["blk.0.ssm_a"], []uint64{4, 1}; !slices.Equal(gotShape, want) { + t.Fatalf("unexpected ssm_a shape: got %v want %v", gotShape, want) + } + if gotShape, want := got["v.position_embd"], []uint64{16384, 1280}; !slices.Equal(gotShape, want) { + t.Fatalf("unexpected position embedding shape: got %v want %v", gotShape, want) + } + if gotShape, want := got["a.blk.0.conv_dw.weight"], []uint64{1024, 9}; !slices.Equal(gotShape, want) { + t.Fatalf("unexpected audio conv_dw shape: got %v want %v", gotShape, want) + } + if gotShape, want := got["a.blk.0.conv_pw1.weight"], []uint64{2048, 1024}; !slices.Equal(gotShape, want) { + t.Fatalf("unexpected audio conv_pw1 shape: got %v want %v", gotShape, want) + } + if _, ok := got["a.blk.0.conv_bn.num_batches_tracked"]; ok { + t.Fatal("audio batchnorm num_batches_tracked should be omitted") + } +} + +func TestNemotronHNanoVLReplacements(t *testing.T) { + m := &nemotronHNanoVLModel{} + r := strings.NewReplacer(m.Replacements()...) 
+ + if got, want := r.Replace("language_model.backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want { + t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want) + } + if got, want := r.Replace("language_model.lm_head.weight"), "output.weight"; got != want { + t.Fatalf("unexpected lm_head replacement: got %q want %q", got, want) + } + if got, want := r.Replace("vision_model.radio_model.model.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want { + t.Fatalf("unexpected vision replacement: got %q want %q", got, want) + } + if got, want := r.Replace("mlp1.1.weight"), "mm.1.weight"; got != want { + t.Fatalf("unexpected projector replacement: got %q want %q", got, want) + } + if got, want := r.Replace("sound_encoder.encoder.layers.0.self_attn.q_proj.weight"), "a.blk.0.attn_q.weight"; got != want { + t.Fatalf("unexpected audio q_proj replacement: got %q want %q", got, want) + } + if got, want := r.Replace("sound_encoder.encoder.layers.0.conv.pointwise_conv1.weight"), "a.blk.0.conv_pw1.weight"; got != want { + t.Fatalf("unexpected audio conv replacement: got %q want %q", got, want) + } + if got, want := r.Replace("sound_projection.linear2.weight"), "mm.a.2.weight"; got != want { + t.Fatalf("unexpected audio projector replacement: got %q want %q", got, want) + } +} + func TestNemotronHReplacementsLatentProjections(t *testing.T) { m := &nemotronHModel{} r := strings.NewReplacer(m.Replacements()...) 
diff --git a/convert/reader.go b/convert/reader.go index ec2da4a23..0a723cec4 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -42,8 +42,10 @@ func (t tensorBase) Kind() uint32 { strings.HasSuffix(t.name, ".bias") || strings.HasSuffix(t.name, ".shortconv.conv.weight") || strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal - strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col - strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32 + strings.HasPrefix(t.name, "a.feature_extractor.") || // audio feature-extractor constants are read with BackendGet and must be real F32 values + strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights are kept F32 for im2col; this likely slows audio and should be revisited + strings.HasPrefix(t.name, "a.subsampling.") || // audio Parakeet subsampling weights are kept F32 for conv/linear stability; this likely slows audio and should be revisited + strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights are kept F32; this likely slows audio and should be revisited t.name == "token_types.weight" || t.name == "v.positional_embedding_vlm" || t.name == "v.position_embd.weight" || diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index 6127ab566..b67fe4fcd 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -5,10 +5,12 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "io/fs" "maps" + "math" "slices" "strings" @@ -23,6 +25,11 @@ type safetensorMetadata struct { } func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) { + fp8Block, err := safetensorsFP8BlockSize(fsys) + if err != nil { + return nil, err + } + var ts []Tensor for _, p := range ps { f, err := fsys.Open(p) @@ -50,24 +57,47 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T names := 
make(map[string]struct{}, len(keys)) + fp8Scales, err := collectSafetensorsFP8Scales(n, headers) + if err != nil { + return nil, err + } + for _, key := range keys { if value := headers[key]; value.Type != "" { + if _, ok := fp8Scales.consumed[key]; ok { + continue + } + // Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors. // Promote them to 1-dim so they can be stored in GGUF. if len(value.Shape) == 0 { value.Shape = []uint64{1} } + + var scale *safetensorScale + if value.Type == "F8_E4M3" { + if !fp8Block.ok { + return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", key) + } + scale = fp8Scales.byWeight[key] + if scale == nil { + return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", key) + } + } + ggufName := replacer.Replace(key) if _, ok := names[ggufName]; ok { return nil, fmt.Errorf("duplicate tensor name '%s' was found for this model", ggufName) } names[ggufName] = struct{}{} ts = append(ts, safetensor{ - fs: fsys, - path: p, - dtype: value.Type, - offset: safetensorsPad(n, value.Offsets[0]), - size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]), + fs: fsys, + path: p, + dtype: value.Type, + offset: safetensorsPad(n, value.Offsets[0]), + size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]), + scale: scale, + fp8Block: fp8Block, tensorBase: &tensorBase{ name: ggufName, shape: value.Shape, @@ -85,12 +115,22 @@ func safetensorsPad(n, offset int64) int64 { return 8 + n + offset } -type safetensor struct { - fs fs.FS - path string +type safetensorScale struct { + name string dtype string + shape []uint64 offset int64 size int64 +} + +type safetensor struct { + fs fs.FS + path string + dtype string + offset int64 + size int64 + scale *safetensorScale + fp8Block safetensorFP8BlockSize *tensorBase } @@ -104,17 +144,26 @@ func (st safetensor) Kind() uint32 { kind != tensorKindFP32 { kind = tensorKindBF16 } + if st.dtype == "F8_E4M3" && kind != tensorKindFP32 { + 
kind = tensorKindBF16 + } return kind } +func (st safetensor) SourceDType() string { + return st.dtype +} + func (st safetensor) Clone() Tensor { return &safetensor{ - fs: st.fs, - path: st.path, - dtype: st.dtype, - offset: st.offset, - size: st.size, + fs: st.fs, + path: st.path, + dtype: st.dtype, + offset: st.offset, + size: st.size, + scale: st.scale.Clone(), + fp8Block: st.fp8Block, tensorBase: &tensorBase{ name: st.name, repacker: st.repacker, @@ -123,6 +172,19 @@ func (st safetensor) Clone() Tensor { } } +func (ss *safetensorScale) Clone() *safetensorScale { + if ss == nil { + return nil + } + return &safetensorScale{ + name: ss.name, + dtype: ss.dtype, + shape: slices.Clone(ss.shape), + offset: ss.offset, + size: ss.size, + } +} + func (st safetensor) WriteTo(w io.Writer) (int64, error) { f, err := st.fs.Open(st.path) if err != nil { @@ -180,6 +242,16 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { } f32s = bfloat16.DecodeFloat32(u8s) + case "F8_E4M3": + u8s := make([]uint8, st.size) + if err = binary.Read(br, binary.LittleEndian, u8s); err != nil { + return 0, err + } + + f32s, err = st.decodeFP8E4M3(u8s) + if err != nil { + return 0, err + } default: return 0, fmt.Errorf("unknown data type: %s", st.dtype) } @@ -208,3 +280,334 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { return 0, fmt.Errorf("unknown storage type: %d", st.Kind()) } } + +type safetensorsFP8Scales struct { + byWeight map[string]*safetensorScale + consumed map[string]struct{} +} + +func collectSafetensorsFP8Scales(n int64, headers map[string]safetensorMetadata) (safetensorsFP8Scales, error) { + scales := safetensorsFP8Scales{ + byWeight: make(map[string]*safetensorScale), + consumed: make(map[string]struct{}), + } + + for key, value := range headers { + if value.Type != "F8_E4M3" { + continue + } + + scaleKey, scaleValue, ok, err := safetensorsFP8Scale(key, headers) + if err != nil { + return safetensorsFP8Scales{}, err + } + if !ok { + continue + } + if _, ok 
:= scales.consumed[scaleKey]; ok { + return safetensorsFP8Scales{}, fmt.Errorf("fp8 scale companion %q is used by multiple tensors", scaleKey) + } + + scales.byWeight[key] = &safetensorScale{ + name: scaleKey, + dtype: scaleValue.Type, + shape: slices.Clone(scaleValue.Shape), + offset: safetensorsPad(n, scaleValue.Offsets[0]), + size: safetensorsPad(n, scaleValue.Offsets[1]) - safetensorsPad(n, scaleValue.Offsets[0]), + } + scales.consumed[scaleKey] = struct{}{} + } + + return scales, nil +} + +func safetensorsFP8Scale(key string, headers map[string]safetensorMetadata) (string, safetensorMetadata, bool, error) { + candidates := safetensorsFP8ScaleCandidates(key) + + var scaleKey string + var scaleValue safetensorMetadata + if strings.HasSuffix(key, ".weight") { + // Keep support for compressed-tensors exports that place the scale name + // between the module path and weight suffix. + base := strings.TrimSuffix(key, ".weight") + candidates = appendUnique(candidates, base+".weight_scale") + candidates = appendUnique(candidates, base+".weight_scale_inv") + } + + for _, candidate := range candidates { + if value, ok := headers[candidate]; ok && value.Type != "" { + if scaleKey != "" { + return "", safetensorMetadata{}, false, fmt.Errorf("multiple fp8 scale companions for tensor %q: %q and %q", key, scaleKey, candidate) + } + scaleKey = candidate + scaleValue = value + } + } + if scaleKey == "" { + return "", safetensorMetadata{}, false, nil + } + + return scaleKey, scaleValue, true, nil +} + +func safetensorsFP8ScaleCandidates(key string) []string { + var candidates []string + candidates = appendUnique(candidates, key+"_scale") + candidates = appendUnique(candidates, key+"_scale_inv") + candidates = appendUnique(candidates, key+".scale") + candidates = appendUnique(candidates, key+".scale_inv") + return candidates +} + +func appendUnique(values []string, value string) []string { + if !slices.Contains(values, value) { + values = append(values, value) + } + return values 
+} + +type safetensorFP8BlockSize struct { + rows int + cols int + ok bool +} + +type safetensorsSourceQuantization struct { + QuantMethod string `json:"quant_method"` + Format string `json:"format"` + WeightBlockSize []int `json:"weight_block_size"` + ConfigGroups map[string]struct { + Format string `json:"format"` + Weights struct { + BlockStructure []int `json:"block_structure"` + NumBits int `json:"num_bits"` + Type string `json:"type"` + } `json:"weights"` + } `json:"config_groups"` +} + +type safetensorsModelConfig struct { + Quantization safetensorsSourceQuantization `json:"quantization"` + QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"` + CompressionConfig safetensorsSourceQuantization `json:"compression_config"` + TextConfig struct { + Quantization safetensorsSourceQuantization `json:"quantization"` + QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"` + CompressionConfig safetensorsSourceQuantization `json:"compression_config"` + } `json:"text_config"` +} + +func safetensorsFP8BlockSize(fsys fs.FS) (safetensorFP8BlockSize, error) { + bts, err := fs.ReadFile(fsys, "config.json") + if errors.Is(err, fs.ErrNotExist) { + return safetensorFP8BlockSize{}, nil + } + if err != nil { + return safetensorFP8BlockSize{}, err + } + bts = sanitizeNonFiniteJSON(bts) + + var cfg safetensorsModelConfig + if err := json.Unmarshal(bts, &cfg); err != nil { + return safetensorFP8BlockSize{}, fmt.Errorf("parse config.json fp8 metadata: %w", err) + } + + var blocks []safetensorFP8BlockSize + for _, q := range []safetensorsSourceQuantization{ + cfg.Quantization, + cfg.QuantizationConfig, + cfg.CompressionConfig, + cfg.TextConfig.Quantization, + cfg.TextConfig.QuantizationConfig, + cfg.TextConfig.CompressionConfig, + } { + if strings.EqualFold(q.QuantMethod, "fp8") && len(q.WeightBlockSize) == 2 { + block, err := newSafetensorFP8BlockSize(q.WeightBlockSize[0], q.WeightBlockSize[1]) + if err != nil { + return 
safetensorFP8BlockSize{}, err + } + blocks = append(blocks, block) + } + + if !strings.EqualFold(q.QuantMethod, "compressed-tensors") && !strings.EqualFold(q.Format, "float-quantized") { + continue + } + for _, group := range q.ConfigGroups { + if !strings.EqualFold(group.Format, "float-quantized") || + group.Weights.NumBits != 8 || + !strings.EqualFold(group.Weights.Type, "float") || + len(group.Weights.BlockStructure) != 2 { + continue + } + block, err := newSafetensorFP8BlockSize(group.Weights.BlockStructure[0], group.Weights.BlockStructure[1]) + if err != nil { + return safetensorFP8BlockSize{}, err + } + blocks = append(blocks, block) + } + } + + if len(blocks) == 0 { + return safetensorFP8BlockSize{}, nil + } + + block := blocks[0] + for _, other := range blocks[1:] { + if other.rows != block.rows || other.cols != block.cols { + return safetensorFP8BlockSize{}, fmt.Errorf("multiple fp8 block sizes in config.json: %dx%d and %dx%d", block.rows, block.cols, other.rows, other.cols) + } + } + return block, nil +} + +func newSafetensorFP8BlockSize(rows, cols int) (safetensorFP8BlockSize, error) { + if rows <= 0 || cols <= 0 { + return safetensorFP8BlockSize{}, fmt.Errorf("invalid fp8 block size %dx%d", rows, cols) + } + return safetensorFP8BlockSize{rows: rows, cols: cols, ok: true}, nil +} + +func (st safetensor) decodeFP8E4M3(data []byte) ([]float32, error) { + if st.scale == nil { + return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", st.name) + } + if !st.fp8Block.ok { + return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", st.name) + } + if len(st.shape) != 2 { + return nil, fmt.Errorf("expected 2D fp8 tensor %q, got shape %v", st.name, st.shape) + } + + rows, cols := int(st.shape[0]), int(st.shape[1]) + if rows < 0 || cols < 0 || rows*cols != len(data) { + return nil, fmt.Errorf("fp8 tensor %q shape %v does not match %d bytes", st.name, st.shape, len(data)) + } + + scale, err := st.readScale() + if err != nil { + return nil, 
err + } + + if len(st.scale.shape) != 2 { + return nil, fmt.Errorf("expected 2D fp8 scale tensor %q, got shape %v", st.scale.name, st.scale.shape) + } + + blockRows := st.fp8Block.rows + blockCols := st.fp8Block.cols + scaleRows, scaleCols := int(st.scale.shape[0]), int(st.scale.shape[1]) + expectedRows := (rows + blockRows - 1) / blockRows + expectedCols := (cols + blockCols - 1) / blockCols + if scaleRows != expectedRows || scaleCols != expectedCols { + return nil, fmt.Errorf("unexpected fp8 scale shape %v for tensor %q shape %v; want [%d %d]", st.scale.shape, st.name, st.shape, expectedRows, expectedCols) + } + if len(scale) != scaleRows*scaleCols { + return nil, fmt.Errorf("fp8 scale tensor %q shape %v does not match decoded length %d", st.scale.name, st.scale.shape, len(scale)) + } + + f32s := make([]float32, len(data)) + for r := range rows { + scaleRow := r / blockRows + rowOffset := r * cols + for c := range cols { + f32s[rowOffset+c] = decodeFloat8E4M3FN(data[rowOffset+c]) * scale[scaleRow*scaleCols+c/blockCols] + } + } + + return f32s, nil +} + +func (st safetensor) readScale() ([]float32, error) { + r, err := st.sectionReader(st.scale.offset, st.scale.size) + if err != nil { + return nil, fmt.Errorf("failed to read fp8 scale tensor %q: %w", st.scale.name, err) + } + if closer, ok := r.(io.Closer); ok { + defer closer.Close() + } + br := bufio.NewReaderSize(r, min(32<<10, int(st.scale.size))) + + switch st.scale.dtype { + case "F32": + f32s := make([]float32, st.scale.size/4) + if err := binary.Read(br, binary.LittleEndian, f32s); err != nil { + return nil, err + } + return f32s, nil + case "F16": + u16s := make([]uint16, st.scale.size/2) + if err := binary.Read(br, binary.LittleEndian, u16s); err != nil { + return nil, err + } + f32s := make([]float32, len(u16s)) + for i := range u16s { + f32s[i] = float16.Frombits(u16s[i]).Float32() + } + return f32s, nil + case "BF16": + u8s := make([]uint8, st.scale.size) + if err := binary.Read(br, 
binary.LittleEndian, u8s); err != nil { + return nil, err + } + return bfloat16.DecodeFloat32(u8s), nil + default: + return nil, fmt.Errorf("unsupported fp8 scale dtype %q for tensor %q", st.scale.dtype, st.scale.name) + } +} + +func (st safetensor) sectionReader(offset, size int64) (io.Reader, error) { + f, err := st.fs.Open(st.path) + if err != nil { + return nil, err + } + + if readerAt, ok := f.(io.ReaderAt); ok { + return &readCloserReader{ + Reader: io.NewSectionReader(readerAt, offset, size), + Closer: f, + }, nil + } + if seeker, ok := f.(io.Seeker); ok { + if _, err := seeker.Seek(offset, io.SeekStart); err != nil { + f.Close() + return nil, err + } + return &readCloserReader{ + Reader: io.LimitReader(f, size), + Closer: f, + }, nil + } + if _, err := io.CopyN(io.Discard, f, offset); err != nil { + f.Close() + return nil, err + } + return &readCloserReader{ + Reader: io.LimitReader(f, size), + Closer: f, + }, nil +} + +type readCloserReader struct { + io.Reader + io.Closer +} + +func decodeFloat8E4M3FN(v byte) float32 { + sign := float32(1) + if v&0x80 != 0 { + sign = -1 + } + + exp := int((v >> 3) & 0x0f) + mant := int(v & 0x07) + if exp == 0 { + if mant == 0 { + return 0 * sign + } + return sign * float32(math.Ldexp(float64(mant)/8, -6)) + } + if exp == 0x0f && mant == 0x07 { + return float32(math.NaN()) + } + + return sign * float32(math.Ldexp(1+float64(mant)/8, exp-7)) +} diff --git a/convert/reader_test.go b/convert/reader_test.go index c3d094f10..632a81aff 100644 --- a/convert/reader_test.go +++ b/convert/reader_test.go @@ -3,8 +3,10 @@ package convert import ( "bytes" "encoding/binary" + "fmt" "os" "path/filepath" + "strings" "testing" "github.com/d4l3k/go-bfloat16" @@ -231,6 +233,222 @@ func TestSafetensors(t *testing.T) { } } +func TestSafetensorWriteToFP8E4M3(t *testing.T) { + root, err := os.OpenRoot(t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer root.Close() + + path := filepath.Base(t.Name()) + f, err := root.Create(path) + if err 
!= nil { + t.Fatal(err) + } + + // E4M3FN encodings for 1.0, 2.0, 0.5, and -1.0. + if _, err := f.Write([]byte{0x38, 0x40, 0x30, 0xb8}); err != nil { + t.Fatal(err) + } + if _, err := f.Write(bfloat16.EncodeFloat32([]float32{2})); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + st := safetensor{ + fs: root.FS(), + path: path, + dtype: "F8_E4M3", + offset: 0, + size: 4, + fp8Block: safetensorFP8BlockSize{rows: 128, cols: 128, ok: true}, + scale: &safetensorScale{ + name: "linear.weight_scale", + dtype: "BF16", + shape: []uint64{1, 1}, + offset: 4, + size: 2, + }, + tensorBase: &tensorBase{ + name: "linear.weight", + shape: []uint64{2, 2}, + }, + } + + var b bytes.Buffer + if _, err := st.WriteTo(&b); err != nil { + t.Fatal(err) + } + + want := bfloat16.EncodeFloat32([]float32{2, 4, 1, -2}) + if diff := cmp.Diff(want, b.Bytes()); diff != "" { + t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff) + } +} + +func TestSafetensorWriteToFP8E4M3UsesConfiguredBlockSize(t *testing.T) { + root, err := os.OpenRoot(t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer root.Close() + + path := filepath.Base(t.Name()) + f, err := root.Create(path) + if err != nil { + t.Fatal(err) + } + + if _, err := f.Write(bytes.Repeat([]byte{0x38}, 12)); err != nil { + t.Fatal(err) + } + if _, err := f.Write(bfloat16.EncodeFloat32([]float32{1, 2, 3, 4})); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + st := safetensor{ + fs: root.FS(), + path: path, + dtype: "F8_E4M3", + offset: 0, + size: 12, + fp8Block: safetensorFP8BlockSize{rows: 2, cols: 3, ok: true}, + scale: &safetensorScale{ + name: "linear.weight_scale", + dtype: "BF16", + shape: []uint64{2, 2}, + offset: 12, + size: 8, + }, + tensorBase: &tensorBase{ + name: "linear.weight", + shape: []uint64{3, 4}, + }, + } + + var b bytes.Buffer + if _, err := st.WriteTo(&b); err != nil { + t.Fatal(err) + } + + want := 
bfloat16.EncodeFloat32([]float32{ + 1, 1, 1, 2, + 1, 1, 1, 2, + 3, 3, 3, 4, + }) + if diff := cmp.Diff(want, b.Bytes()); diff != "" { + t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff) + } +} + +func TestParseSafetensorsConsumesFP8ScaleCompanion(t *testing.T) { + tempDir := t.TempDir() + generateSafetensorTestData(t, tempDir, map[string]*tensorData{ + "linear.weight": { + Offsets: []int{0, 4}, + Type: "F8_E4M3", + Shape: []int{2, 2}, + }, + "linear.weight_scale": { + Offsets: []int{4, 6}, + Type: "BF16", + Shape: []int{1, 1}, + }, + }) + writeFP8BlockConfig(t, tempDir, 128, 128) + + tensors, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors") + if err != nil { + t.Fatal(err) + } + if len(tensors) != 1 { + t.Fatalf("expected one tensor, got %d", len(tensors)) + } + if got := tensors[0].Name(); got != "linear.weight" { + t.Fatalf("unexpected tensor name %q", got) + } + if got := tensors[0].Kind(); got != tensorKindBF16 { + t.Fatalf("unexpected fp8 converted kind %d, want %d", got, tensorKindBF16) + } +} + +func TestParseSafetensorsRejectsFP8WithoutBlockMetadata(t *testing.T) { + tempDir := t.TempDir() + generateSafetensorTestData(t, tempDir, map[string]*tensorData{ + "linear.weight": { + Offsets: []int{0, 4}, + Type: "F8_E4M3", + Shape: []int{2, 2}, + }, + "linear.weight_scale": { + Offsets: []int{4, 6}, + Type: "BF16", + Shape: []int{1, 1}, + }, + }) + + _, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors") + if err == nil || !strings.Contains(err.Error(), "missing fp8 block size metadata") { + t.Fatalf("expected missing fp8 block size metadata error, got %v", err) + } +} + +func TestParseSafetensorsRejectsAmbiguousFP8ScaleCompanion(t *testing.T) { + tempDir := t.TempDir() + generateSafetensorTestData(t, tempDir, map[string]*tensorData{ + "linear.weight": { + Offsets: []int{0, 4}, + Type: "F8_E4M3", + Shape: []int{2, 2}, + }, + "linear.weight_scale": 
{ + Offsets: []int{4, 6}, + Type: "BF16", + Shape: []int{1, 1}, + }, + "linear.weight.scale": { + Offsets: []int{6, 8}, + Type: "BF16", + Shape: []int{1, 1}, + }, + }) + writeFP8BlockConfig(t, tempDir, 128, 128) + + _, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors") + if err == nil || !strings.Contains(err.Error(), "multiple fp8 scale companions") { + t.Fatalf("expected ambiguous fp8 scale companion error, got %v", err) + } +} + +func writeFP8BlockConfig(t *testing.T, dir string, rows, cols int) { + t.Helper() + + config := fmt.Sprintf(`{ + "architectures": ["GenericForCausalLM"], + "compression_config": { + "format": "float-quantized", + "config_groups": { + "group_0": { + "format": "float-quantized", + "weights": { + "type": "float", + "num_bits": 8, + "block_structure": [%d, %d] + } + } + } + } +}`, rows, cols) + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(config), 0o644); err != nil { + t.Fatal(err) + } +} + func TestSafetensorKind(t *testing.T) { tests := []struct { name string @@ -259,6 +477,17 @@ func TestSafetensorKind(t *testing.T) { }, expected: tensorKindFP16, }, + { + name: "BF16 audio feature extractor constants should return FP32", + st: safetensor{ + tensorBase: &tensorBase{ + name: "a.feature_extractor.fb", + shape: []uint64{1, 128, 257}, + }, + dtype: "BF16", + }, + expected: tensorKindFP32, + }, { name: "BF16 dtype with FP32 base kind should return FP32", st: safetensor{ diff --git a/convert/tensor.go b/convert/tensor.go index 68870744f..5dad64c8d 100644 --- a/convert/tensor.go +++ b/convert/tensor.go @@ -5,6 +5,7 @@ import ( "errors" "io" "iter" + "maps" "path" "slices" "strconv" @@ -153,3 +154,54 @@ func (g mergeGroup) WriteTo(w io.Writer) (int64, error) { return 0, nil } + +func sourceTensorKV(ts []*ggml.Tensor) KV { + sourceFP8 := make(map[string]struct{}) + for _, t := range ts { + if writerSourceDType(t.WriterTo) == "F8_E4M3" { + sourceFP8[t.Name] = struct{}{} + } + } 
+ if len(sourceFP8) == 0 { + return nil + } + + return KV{ + "source_quantization": "hf_fp8", + "source_fp8_tensors": slices.Sorted(maps.Keys(sourceFP8)), + } +} + +type sourceDTypeTensor interface { + SourceDType() string +} + +func writerSourceDType(w io.WriterTo) string { + switch w := w.(type) { + case sourceDTypeTensor: + return w.SourceDType() + case mergeGroup: + if len(w) == 0 { + return "" + } + dtype := sourceDType(w[0]) + if dtype == "" { + return "" + } + for _, t := range w[1:] { + if sourceDType(t) != dtype { + return "" + } + } + return dtype + default: + return "" + } +} + +func sourceDType(t Tensor) string { + if t, ok := t.(sourceDTypeTensor); ok { + return t.SourceDType() + } + return "" +} diff --git a/convert/tensor_test.go b/convert/tensor_test.go index e0dc2350a..1dea476f2 100644 --- a/convert/tensor_test.go +++ b/convert/tensor_test.go @@ -21,7 +21,8 @@ type fakeTensor struct { shape []uint64 data []float32 - repacker Repacker + sourceDType string + repacker Repacker } func (f fakeTensor) Name() string { @@ -36,16 +37,21 @@ func (f fakeTensor) Kind() uint32 { return 0 } +func (f fakeTensor) SourceDType() string { + return f.sourceDType +} + func (f *fakeTensor) SetRepacker(fn Repacker) { f.repacker = fn } func (f fakeTensor) Clone() Tensor { return &fakeTensor{ - name: f.name, - shape: slices.Clone(f.shape), - data: slices.Clone(f.data), - repacker: f.repacker, + name: f.name, + shape: slices.Clone(f.shape), + data: slices.Clone(f.data), + sourceDType: f.sourceDType, + repacker: f.repacker, } } @@ -995,3 +1001,43 @@ func TestMergeOrder(t *testing.T) { }) } } + +func TestSourceTensorKVRecordsFP8OutputTensors(t *testing.T) { + fp8 := &fakeTensor{name: "linear.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"} + bf16 := &fakeTensor{name: "other.weight", shape: []uint64{2, 2}, sourceDType: "BF16"} + + kv := sourceTensorKV([]*ggml.Tensor{ + {Name: "blk.0.linear.weight", WriterTo: fp8}, + {Name: "blk.0.other.weight", WriterTo: bf16}, + }) + + 
if got := kv["source_quantization"]; got != "hf_fp8" { + t.Fatalf("source_quantization = %v, want hf_fp8", got) + } + got, ok := kv["source_fp8_tensors"].([]string) + if !ok { + t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"]) + } + if diff := cmp.Diff([]string{"blk.0.linear.weight"}, got); diff != "" { + t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff) + } +} + +func TestSourceTensorKVRecordsMergedFP8OutputTensors(t *testing.T) { + fp8A := &fakeTensor{name: "expert.0.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"} + fp8B := &fakeTensor{name: "expert.1.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"} + bf16 := &fakeTensor{name: "expert.2.weight", shape: []uint64{2, 2}, sourceDType: "BF16"} + + kv := sourceTensorKV([]*ggml.Tensor{ + {Name: "ffn_exps.weight", WriterTo: mergeGroup{fp8A, fp8B}}, + {Name: "mixed_exps.weight", WriterTo: mergeGroup{fp8A, bf16}}, + }) + + got, ok := kv["source_fp8_tensors"].([]string) + if !ok { + t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"]) + } + if diff := cmp.Diff([]string{"ffn_exps.weight"}, got); diff != "" { + t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff) + } +} diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 213ef11fb..7f642331c 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -103,6 +103,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) t.Pre = "qwen2" case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba": t.Pre = "qwen35" + case "b92c0824a58e1d8dc3221cf3e12c433c3a86f57e46d57229993489f0798e7702": + t.Pre = "laguna" case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855": // noop, empty pretokenizer default: diff --git a/docs/docs.json b/docs/docs.json index 17884d992..38c083f1b 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -124,7 +124,8 @@ "/integrations/opencode", "/integrations/droid", "/integrations/goose", - 
"/integrations/pi" + "/integrations/pi", + "/integrations/poolside" ] }, { diff --git a/docs/integrations/index.mdx b/docs/integrations/index.mdx index 9bb9d33b5..94c8f326a 100644 --- a/docs/integrations/index.mdx +++ b/docs/integrations/index.mdx @@ -15,6 +15,7 @@ Coding assistants that can read, modify, and execute code in your projects. - [Droid](/integrations/droid) - [Goose](/integrations/goose) - [Pi](/integrations/pi) +- [Poolside](/integrations/poolside) ## Assistants diff --git a/docs/integrations/poolside.mdx b/docs/integrations/poolside.mdx new file mode 100644 index 000000000..846081832 --- /dev/null +++ b/docs/integrations/poolside.mdx @@ -0,0 +1,54 @@ +--- +title: Poolside +--- + +Poolside is Poolside's software agent for the terminal, built for enterprise development workflows. + +## Install + +Install [Poolside](https://github.com/poolsideai/pool). + +## Usage with Ollama + +### Quick setup + +```shell +ollama launch pool +``` + +### Run directly with a model + +```shell +ollama launch pool --model kimi-k2.6:cloud +``` + +### Pass arguments through to Poolside + +Arguments after `--` are passed directly to Poolside: + +```shell +ollama launch pool -- --help +``` + +## Manual setup + +Poolside connects to Ollama using the OpenAI-compatible API via environment variables. + +1. Set the environment variables: + +```shell +export POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1 +export POOLSIDE_API_KEY=ollama +``` + +2. 
Run Poolside with an Ollama model: + +```shell +pool -m kimi-k2.6:cloud +``` + +Or run with environment variables inline: + +```shell +POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1 POOLSIDE_API_KEY=ollama pool -m kimi-k2.6:cloud +``` diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index e23d345cd..64eebb78c 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -283,10 +283,11 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3n", "gemma4", "gptoss", "gpt-oss", + "laguna", "llama4", "mistral3", "mllama", - "nemotron_h", "nemotron_h_moe", + "nemotron_h", "nemotron_h_moe", "nemotron_h_omni", "nomic-bert", "olmo3", "qwen25vl", @@ -897,7 +898,7 @@ func (f GGML) FlashAttention() bool { "lfm2", "lfm2moe", "mistral3", - "nemotron_h", "nemotron_h_moe", + "nemotron_h", "nemotron_h_moe", "nemotron_h_omni", "olmo3", "qwen3", "qwen3moe", "qwen35", "qwen35moe", diff --git a/model/models/laguna/model.go b/model/models/laguna/model.go new file mode 100644 index 000000000..82a913e58 --- /dev/null +++ b/model/models/laguna/model.go @@ -0,0 +1,444 @@ +package laguna + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" + "github.com/ollama/ollama/tokenizer" +) + +const ( + cacheTypeSWA = iota + cacheTypeCausal +) + +type Options struct { + hiddenSize int + headDim int + + numHeads []int + numKVHeads int + + eps float32 + + slidingWindow int + slidingWindowPattern []bool + + fullRopeDim int + fullRopeBase, fullRopeScale float32 + fullRopeOriginalContextLength int + fullRopeAttentionFactor float32 + fullRopeBetaFast float32 + fullRopeBetaSlow float32 + + swaRopeDim int + swaRopeBase, swaRopeScale float32 + numExperts, numExpertsUsed int + normTopKProb bool + routedScalingFactor float32 + decoderSparseStep int + denseLayers map[int]bool +} + +func (o 
*Options) numHeadsForLayer(layer int) int { + if layer < len(o.numHeads) && o.numHeads[layer] > 0 { + return o.numHeads[layer] + } + if len(o.numHeads) > 0 && o.numHeads[0] > 0 { + return o.numHeads[0] + } + return 1 +} + +func (o *Options) layerIsSliding(layer int) bool { + return layer < len(o.slidingWindowPattern) && o.slidingWindowPattern[layer] +} + +func (o *Options) layerUsesMoE(layer int) bool { + if o.numExperts == 0 || o.denseLayers[layer] { + return false + } + step := o.decoderSparseStep + if step <= 0 { + step = 1 + } + return (layer+1)%step == 0 +} + +func (o *Options) applyRotaryPositionEmbeddings(ctx ml.Context, layer int, states, positions ml.Tensor) ml.Tensor { + opts := []func(*rope.Options){rope.WithTypeNeoX()} + if o.layerIsSliding(layer) { + return nn.RoPE(ctx, states, positions, o.swaRopeDim, o.swaRopeBase, 1./o.swaRopeScale, opts...) + } + + opts = append(opts, + rope.WithOriginalContextLength(o.fullRopeOriginalContextLength), + rope.WithExtrapolationFactor(1), + rope.WithAttentionFactor(o.fullRopeAttentionFactor), + rope.WithBetaFast(o.fullRopeBetaFast), + rope.WithBetaSlow(o.fullRopeBetaSlow), + ) + return nn.RoPE(ctx, states, positions, o.fullRopeDim, o.fullRopeBase, 1./o.fullRopeScale, opts...) 
+} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Value *nn.Linear `gguf:"attn_v"` + Gate *nn.Linear `gguf:"attn_g"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *Attention) Forward(ctx ml.Context, layer int, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + numHeads := opts.numHeadsForLayer(layer) + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + gate := sa.Gate.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim, numHeads, batchSize) + key = key.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize) + value = value.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize) + + query = sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + query = opts.applyRotaryPositionEmbeddings(ctx, layer, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, layer, key, positions) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), cache) + // Laguna applies the per-head gate softplus in float32, then casts back. 
+ gate = gate.Cast(ctx, ml.DTypeF32).Softplus(ctx).Cast(ctx, attention.DType()) + attention = attention.Mul(ctx, gate.Reshape(ctx, 1, numHeads, batchSize)) + attention = attention.Reshape(ctx, opts.headDim*numHeads, batchSize) + return sa.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` + SharedExpert *dense `gguf:",suf:_shexp"` + ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { + if moe.ExpProbsBias != nil { + scores = scores.Add(ctx, moe.ExpProbsBias) + } + return scores.TopK(ctx, opts.numExpertsUsed) +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + residual := hiddenStates + + scores := moe.Router.Forward(ctx, hiddenStates).Cast(ctx, ml.DTypeF32).Sigmoid(ctx) + selectedExperts := moe.topKIndices(ctx, scores, opts) + routingWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx)) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + routingWeights = routingWeights.Scale(ctx, float64(opts.routedScalingFactor)) + 
+ hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + upStates := moe.Up.Forward(ctx, hiddenStates, selectedExperts) + hiddenStates = moe.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, upStates) + + experts := moe.Down.Forward(ctx, hiddenStates, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates.Add(ctx, moe.SharedExpert.Forward(ctx, residual, opts)) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP +} + +func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = l.Attention.Forward(ctx, layer, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = l.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type Model struct { + model.Base + tokenizer.Tokenizer + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *Options +} + +func New(c fs.Config) (model.Model, error) { + if c.Bool("attention.sink_enabled") { + return nil, fmt.Errorf("laguna: SWA attention sinks are not supported") + } + if c.Uint("attention.gating_type") != 1 { + return 
nil, fmt.Errorf("laguna: unsupported attention gating type %d", c.Uint("attention.gating_type")) + } + if !c.Bool("attention.qk_norm") { + return nil, fmt.Errorf("laguna: Q/K RMSNorm is required") + } + if gating := c.Uint("expert_gating_func"); gating != 2 { + return nil, fmt.Errorf("laguna: unsupported expert gating function %d", gating) + } + + numLayers := int(c.Uint("block_count")) + opts := newOptions(c, numLayers) + layers := make([]Layer, numLayers) + for i := range layers { + if opts.layerUsesMoE(i) { + layers[i].MLP = &sparse{} + } else { + layers[i].MLP = &dense{} + } + } + + var pre []string + switch c.String("tokenizer.ggml.pre") { + case "laguna": + pre = []string{ + `(?:\r?\n)+(?!\r?\n)`, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + } + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + Tokenizer: tokenizer.NewBytePairEncoding( + &tokenizer.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + pre..., + ), + Layers: layers, + Options: opts, + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(opts.slidingWindow), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + return &m, nil +} + +func newOptions(c fs.Config, numLayers int) *Options { + denseLayers := make(map[int]bool) + for _, layer := range configUints(c, "dense_layers") { + denseLayers[int(layer)] = true + } + for i := range c.Uint("leading_dense_block_count") { + denseLayers[int(i)] = true + } + + fullRopeScale := c.Float("rope.scaling.factor", 1) + if fullRopeScale == 0 { + 
fullRopeScale = 1 + } + swaRopeScale := c.Float("rope.swa.scaling.factor", 1) + if swaRopeScale == 0 { + swaRopeScale = 1 + } + fullRopeType := c.String("rope.scaling.type") + fullRopeAttentionFactor := lagunaAttentionFactor(fullRopeType, fullRopeScale, c.Float("rope.scaling.attn_factor")) + + return &Options{ + hiddenSize: int(c.Uint("embedding_length")), + headDim: int(c.Uint("attention.key_length")), + numHeads: expandIntArray(configUints(c, "attention.head_count"), numLayers, c.Uint("attention.head_count", 1)), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6), + slidingWindow: int(c.Uint("attention.sliding_window", 512)), + slidingWindowPattern: slidingWindowPattern(c, numLayers), + fullRopeDim: int(c.Uint("rope.dimension_count", c.Uint("attention.key_length"))), + fullRopeBase: c.Float("rope.freq_base", 500000), + fullRopeScale: fullRopeScale, + fullRopeOriginalContextLength: int(c.Uint("rope.scaling.original_context_length", 4096)), + fullRopeAttentionFactor: fullRopeAttentionFactor, + fullRopeBetaFast: c.Float("rope.scaling.beta_fast", 64), + fullRopeBetaSlow: c.Float("rope.scaling.beta_slow", 1), + swaRopeDim: int(c.Uint("rope.swa.dimension_count", c.Uint("attention.key_length"))), + swaRopeBase: c.Float("rope.swa.freq_base", 10000), + swaRopeScale: swaRopeScale, + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("expert_weights_norm", true), + routedScalingFactor: c.Float("expert_weights_scale", 1), + decoderSparseStep: int(c.Uint("decoder_sparse_step", 1)), + denseLayers: denseLayers, + } +} + +func lagunaAttentionFactor(ropeType string, scaleFactor, attentionFactor float32) float32 { + if attentionFactor != 0 { + return attentionFactor + } + if ropeType == "yarn" && scaleFactor > 1 { + return float32(0.1*math.Log(float64(scaleFactor)) + 1) + } + return 1 +} + +func slidingWindowPattern(c fs.Config, numLayers int) []bool { + 
pattern := c.Bools("attention.sliding_window_pattern") + if len(pattern) == numLayers { + return pattern + } + + layerTypes := configUints(c, "attention.layer_types") + if len(layerTypes) == numLayers { + pattern = make([]bool, numLayers) + for i, layerType := range layerTypes { + pattern[i] = layerType == 1 + } + return pattern + } + + return make([]bool, numLayers) +} + +func configUints(c fs.Config, key string) []uint32 { + keyExists := c.Value(c.Architecture()+"."+key) != nil || c.Value(key) != nil + if cc, ok := c.(interface { + Uints(string, ...[]uint32) []uint32 + }); ok { + if values := cc.Uints(key); len(values) > 0 && (keyExists || !(len(values) == 1 && values[0] == 0)) { + return values + } + } + + ints := c.Ints(key) + if len(ints) > 0 && (keyExists || !(len(ints) == 1 && ints[0] == 0)) { + values := make([]uint32, len(ints)) + for i, v := range ints { + values[i] = uint32(v) + } + return values + } + + if scalar := c.Uint(key); scalar != 0 { + return []uint32{scalar} + } + return nil +} + +func expandIntArray(values []uint32, n int, fallback uint32) []int { + if len(values) == 0 { + values = []uint32{fallback} + } + defaultValue := values[0] + if len(values) == 1 { + defaultValue = values[0] + } + + out := make([]int, n) + for i := range out { + if i < len(values) { + out[i] = int(values[i]) + } else { + out[i] = int(defaultValue) + } + } + return out +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return m.Options.applyRotaryPositionEmbeddings(ctx, layer, key, shift), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + if m.Cache != nil { + m.Cache.SetLayer(i) + if wrapper, ok := m.Cache.(*kvcache.WrapperCache); ok { + cacheType := cacheTypeCausal + if m.Options.layerIsSliding(i) { + 
cacheType = cacheTypeSWA + } + wrapper.SetLayerType(cacheType) + } + } + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, i, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("laguna", New) +} + +var _ model.Model = (*Model)(nil) diff --git a/model/models/laguna/model_test.go b/model/models/laguna/model_test.go new file mode 100644 index 000000000..ac7b336b2 --- /dev/null +++ b/model/models/laguna/model_test.go @@ -0,0 +1,237 @@ +package laguna + +import ( + "iter" + "math" + "testing" +) + +type testConfig map[string]any + +func (c testConfig) Architecture() string { return "laguna" } + +func (c testConfig) key(key string) string { + switch { + case len(key) >= len("tokenizer.") && key[:len("tokenizer.")] == "tokenizer.": + return key + case len(key) >= len("general.") && key[:len("general.")] == "general.": + return key + default: + return "laguna." 
+ key + } +} + +func (c testConfig) String(key string, defaultValue ...string) string { + if v, ok := c[c.key(key)].(string); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" +} + +func (c testConfig) Uint(key string, defaultValue ...uint32) uint32 { + switch v := c[c.key(key)].(type) { + case uint32: + return v + case int: + return uint32(v) + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return 0 +} + +func (c testConfig) Float(key string, defaultValue ...float32) float32 { + if v, ok := c[c.key(key)].(float32); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return 0 +} + +func (c testConfig) Bool(key string, defaultValue ...bool) bool { + if v, ok := c[c.key(key)].(bool); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false +} + +func (c testConfig) Strings(key string, defaultValue ...[]string) []string { + if v, ok := c[c.key(key)].([]string); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil +} + +func (c testConfig) Ints(key string, defaultValue ...[]int32) []int32 { + if v, ok := c[c.key(key)].([]int32); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil +} + +func (c testConfig) Uints(key string, defaultValue ...[]uint32) []uint32 { + if v, ok := c[c.key(key)].([]uint32); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil +} + +func (c testConfig) Floats(key string, defaultValue ...[]float32) []float32 { + if v, ok := c[c.key(key)].([]float32); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil +} + +func (c testConfig) Bools(key string, defaultValue ...[]bool) []bool { + if v, ok := c[c.key(key)].([]bool); ok { + return v + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil +} + +func (c testConfig) Len() int { return len(c) } + +func (c testConfig) 
Keys() iter.Seq[string] { + return func(yield func(string) bool) { + for key := range c { + if !yield(key) { + return + } + } + } +} + +func (c testConfig) Value(key string) any { return c[key] } + +func TestNewOptionsLayerConfig(t *testing.T) { + cfg := testConfig{ + "laguna.block_count": uint32(4), + "laguna.embedding_length": uint32(128), + "laguna.attention.key_length": uint32(16), + "laguna.attention.head_count": []uint32{8, 16, 16, 16}, + "laguna.attention.head_count_kv": uint32(4), + "laguna.attention.layer_norm_rms_epsilon": float32(1e-6), + "laguna.attention.sliding_window": uint32(512), + "laguna.attention.sliding_window_pattern": []bool{false, true, true, true}, + "laguna.rope.dimension_count": uint32(8), + "laguna.rope.freq_base": float32(500000), + "laguna.rope.scaling.factor": float32(32), + "laguna.rope.scaling.original_context_length": uint32(4096), + "laguna.rope.swa.dimension_count": uint32(16), + "laguna.rope.swa.freq_base": float32(10000), + "laguna.expert_count": uint32(32), + "laguna.expert_used_count": uint32(4), + "laguna.expert_weights_norm": true, + "laguna.expert_weights_scale": float32(2.5), + "laguna.decoder_sparse_step": uint32(1), + "laguna.dense_layers": []uint32{0}, + } + + opts := newOptions(cfg, 4) + if got := opts.numHeadsForLayer(0); got != 8 { + t.Fatalf("layer 0 heads = %d, want 8", got) + } + if got := opts.numHeadsForLayer(1); got != 16 { + t.Fatalf("layer 1 heads = %d, want 16", got) + } + if opts.layerIsSliding(0) { + t.Fatal("layer 0 should be full attention") + } + if !opts.layerIsSliding(1) { + t.Fatal("layer 1 should be sliding attention") + } + if opts.layerUsesMoE(0) { + t.Fatal("layer 0 should be dense") + } + if !opts.layerUsesMoE(1) { + t.Fatal("layer 1 should use MoE") + } + if opts.fullRopeDim != 8 || opts.swaRopeDim != 16 { + t.Fatalf("rope dims = full %d swa %d, want 8/16", opts.fullRopeDim, opts.swaRopeDim) + } +} + +func TestNewOptionsYarnAttentionFactorFallback(t *testing.T) { + cfg := testConfig{ + 
"laguna.block_count": uint32(1), + "laguna.embedding_length": uint32(128), + "laguna.attention.key_length": uint32(16), + "laguna.attention.head_count": uint32(8), + "laguna.attention.head_count_kv": uint32(4), + "laguna.rope.scaling.type": "yarn", + "laguna.rope.scaling.factor": float32(32), + } + + opts := newOptions(cfg, 1) + want := float32(0.1*math.Log(32) + 1) + if got := opts.fullRopeAttentionFactor; math.Abs(float64(got-want)) > 1e-6 { + t.Fatalf("fullRopeAttentionFactor = %v, want %v", got, want) + } +} + +func TestNewRejectsUnsupportedLagunaVariants(t *testing.T) { + tests := []struct { + name string + cfg testConfig + }{ + { + name: "attention sinks", + cfg: testConfig{ + "laguna.attention.sink_enabled": true, + }, + }, + { + name: "non per-head gate", + cfg: testConfig{ + "laguna.attention.gating_type": uint32(0), + }, + }, + { + name: "missing qk norm", + cfg: testConfig{ + "laguna.attention.gating_type": uint32(1), + }, + }, + { + name: "non sigmoid experts", + cfg: testConfig{ + "laguna.attention.gating_type": uint32(1), + "laguna.attention.qk_norm": true, + "laguna.expert_gating_func": uint32(1), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := New(tt.cfg); err == nil { + t.Fatal("expected unsupported variant error") + } + }) + } +} diff --git a/model/models/models.go b/model/models/models.go index 22439f4a2..31a2de444 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -11,6 +11,7 @@ import ( _ "github.com/ollama/ollama/model/models/glm4moelite" _ "github.com/ollama/ollama/model/models/glmocr" _ "github.com/ollama/ollama/model/models/gptoss" + _ "github.com/ollama/ollama/model/models/laguna" _ "github.com/ollama/ollama/model/models/lfm2" _ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama4" diff --git a/model/models/nemotronh/imageproc.go b/model/models/nemotronh/imageproc.go new file mode 100644 index 000000000..9a0336c3e --- /dev/null +++ 
b/model/models/nemotronh/imageproc.go @@ -0,0 +1,355 @@ +package nemotronh + +import ( + "errors" + "image" + "math" + "slices" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize int + patchSize int + numChannels int + maxTiles int + minNumPatches int + maxNumPatches int + useThumbnail bool + projectorScale int + imageMean [3]float32 + imageStd [3]float32 +} + +type processedVisionTile struct { + data []float32 + size image.Point +} + +func newImageProcessor(c fs.Config) ImageProcessor { + mean := c.Floats("vision.image_mean") + std := c.Floats("vision.image_std") + + processor := ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 512)), + patchSize: int(c.Uint("vision.patch_size", 16)), + numChannels: int(c.Uint("vision.num_channels", 3)), + maxTiles: int(c.Uint("vision.max_tiles", 12)), + minNumPatches: int(c.Uint("vision.min_num_patches")), + maxNumPatches: int(c.Uint("vision.max_num_patches")), + useThumbnail: c.Bool("vision.use_thumbnail", true), + projectorScale: int(c.Uint("vision.projector.scale_factor", 2)), + imageMean: imageproc.ClipDefaultMean, + imageStd: imageproc.ClipDefaultSTD, + } + + if len(mean) >= 3 { + processor.imageMean = [3]float32{mean[0], mean[1], mean[2]} + } + if len(std) >= 3 { + processor.imageStd = [3]float32{std[0], std[1], std[2]} + } + if processor.imageSize <= 0 { + processor.imageSize = 512 + } + if processor.patchSize <= 0 { + processor.patchSize = 16 + } + if processor.numChannels <= 0 { + processor.numChannels = 3 + } + if processor.maxTiles <= 0 { + processor.maxTiles = 12 + } + if processor.projectorScale <= 0 { + processor.projectorScale = 2 + } + + return processor +} + +func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionTile, error) { + img = imageproc.Composite(img) + if p.useDynamicResolution() { + return p.processDynamicImage(img) + } + + return p.processTiledImage(img), nil +} + +func (p ImageProcessor) 
useDynamicResolution() bool { + return p.minNumPatches > 0 || p.maxNumPatches > 0 +} + +func (p ImageProcessor) processTiledImage(img image.Image) []processedVisionTile { + bounds := img.Bounds() + origWidth := bounds.Dx() + origHeight := bounds.Dy() + targetRatios := nemotronTargetRatios(p.maxTiles) + gridWidth, gridHeight := findClosestAspectRatio(float64(origWidth)/float64(origHeight), targetRatios, origWidth, origHeight, p.imageSize) + + targetWidth := p.imageSize * gridWidth + targetHeight := p.imageSize * gridHeight + resized := resizeImageBicubicCHW(img, targetWidth, targetHeight) + + tiles := make([]processedVisionTile, 0, gridWidth*gridHeight+1) + for row := range gridHeight { + for col := range gridWidth { + tile := cropCHWRegion( + resized, + targetWidth, + targetHeight, + p.numChannels, + col*p.imageSize, + row*p.imageSize, + p.imageSize, + p.imageSize, + ) + tiles = append(tiles, processedVisionTile{ + data: normalizeVisionCHW(tile, p.imageMean, p.imageStd), + size: image.Point{X: p.imageSize, Y: p.imageSize}, + }) + } + } + + if p.useThumbnail && len(tiles) > 1 { + thumbnail := resizeImageBicubicCHW(img, p.imageSize, p.imageSize) + tiles = append(tiles, processedVisionTile{ + data: normalizeVisionCHW(thumbnail, p.imageMean, p.imageStd), + size: image.Point{X: p.imageSize, Y: p.imageSize}, + }) + } + + return tiles +} + +func (p ImageProcessor) processDynamicImage(img image.Image) ([]processedVisionTile, error) { + bounds := img.Bounds() + origWidth := bounds.Dx() + origHeight := bounds.Dy() + patchesWidth, patchesHeight := p.dynamicPatchGrid(origWidth, origHeight) + if patchesWidth <= 0 || patchesHeight <= 0 { + return nil, errors.New("nemotron_h_omni: invalid dynamic image patch grid") + } + + targetWidth := patchesWidth * p.patchSize + targetHeight := patchesHeight * p.patchSize + resized := resizeImageBicubicCHW(img, targetWidth, targetHeight) + + return []processedVisionTile{{ + data: normalizeVisionCHW(resized, p.imageMean, p.imageStd), + size: 
image.Point{X: targetWidth, Y: targetHeight}, + }}, nil +} + +func (p ImageProcessor) dynamicPatchGrid(origWidth, origHeight int) (int, int) { + patchesHeight := max(1, int(math.Round(float64(origHeight)/float64(p.patchSize)+0.5))) + patchesWidth := max(1, int(math.Round(float64(origWidth)/float64(p.patchSize)+0.5))) + + patches := patchesHeight * patchesWidth + currentNumPatchesAvailable := p.maxNumPatches + if currentNumPatchesAvailable <= 0 { + currentNumPatchesAvailable = max(patches, p.minNumPatches) + } + + factor := math.Min(math.Sqrt(float64(currentNumPatchesAvailable)/float64(patches)), 1.0) + targetPatchesHeight := max(1, int(math.Floor(factor*float64(patchesHeight)))) + targetPatchesWidth := max(1, int(math.Floor(factor*float64(patchesWidth)))) + + if currentNumPatchesAvailable > p.minNumPatches && targetPatchesHeight*targetPatchesWidth < p.minNumPatches { + upFactor := math.Sqrt(float64(p.minNumPatches) / float64(targetPatchesHeight*targetPatchesWidth)) + targetPatchesHeight = int(math.Ceil(upFactor * float64(targetPatchesHeight))) + targetPatchesWidth = int(math.Ceil(upFactor * float64(targetPatchesWidth))) + } + + targetPatchesHeight = roundPatchGridForPixelShuffle(targetPatchesHeight, targetPatchesWidth, currentNumPatchesAvailable, p.projectorScale) + targetPatchesWidth = roundPatchGridForPixelShuffle(targetPatchesWidth, targetPatchesHeight, currentNumPatchesAvailable, p.projectorScale) + + return targetPatchesWidth, targetPatchesHeight +} + +func roundPatchGridForPixelShuffle(v, other, maxPatches, divisor int) int { + if divisor <= 1 { + return v + } + rem := v % divisor + if rem == 0 { + return v + } + + inc := divisor - rem + if (v+inc)*other <= maxPatches { + return v + inc + } + return max(divisor, v-rem) +} + +type nemotronImageRatio struct { + width int + height int +} + +func nemotronTargetRatios(maxTiles int) []nemotronImageRatio { + targetRatios := make([]nemotronImageRatio, 0, maxTiles*maxTiles) + for n := 1; n <= maxTiles; n++ { + for w 
:= 1; w <= n; w++ { + for h := 1; h <= n; h++ { + if w*h > maxTiles { + continue + } + targetRatios = append(targetRatios, nemotronImageRatio{width: w, height: h}) + } + } + } + + unique := targetRatios[:0] + for _, ratio := range targetRatios { + if slices.Contains(unique, ratio) { + continue + } + unique = append(unique, ratio) + } + + slices.SortFunc(unique, func(a, b nemotronImageRatio) int { + return a.width*a.height - b.width*b.height + }) + + return unique +} + +func findClosestAspectRatio(aspectRatio float64, targetRatios []nemotronImageRatio, width, height, imageSize int) (int, int) { + bestRatio := nemotronImageRatio{width: 1, height: 1} + bestRatioDiff := math.MaxFloat64 + area := width * height + + for _, ratio := range targetRatios { + targetAspectRatio := float64(ratio.width) / float64(ratio.height) + ratioDiff := math.Abs(aspectRatio - targetAspectRatio) + if ratioDiff < bestRatioDiff { + bestRatioDiff = ratioDiff + bestRatio = ratio + continue + } + + if ratioDiff == bestRatioDiff && area > int(0.5*float64(imageSize*imageSize*ratio.width*ratio.height)) { + bestRatio = ratio + } + } + + return bestRatio.width, bestRatio.height +} + +func resizeImageBicubicCHW(img image.Image, outW, outH int) []float32 { + bounds := img.Bounds() + inW := bounds.Dx() + inH := bounds.Dy() + src := make([]float32, 3*inW*inH) + + for y := range inH { + for x := range inW { + r, g, b, _ := img.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA() + src[y*inW+x] = float32(r>>8) / 255.0 + src[inW*inH+y*inW+x] = float32(g>>8) / 255.0 + src[2*inW*inH+y*inW+x] = float32(b>>8) / 255.0 + } + } + + dst := make([]float32, 3*outW*outH) + scaleX := float64(inW) / float64(outW) + scaleY := float64(inH) / float64(outH) + + for oy := range outH { + srcY := scaleY*(float64(oy)+0.5) - 0.5 + yBase := int(math.Floor(srcY)) + yFrac := clampUnit(srcY - float64(yBase)) + wy := torchBicubicWeights(yFrac) + + for ox := range outW { + srcX := scaleX*(float64(ox)+0.5) - 0.5 + xBase := int(math.Floor(srcX)) + 
xFrac := clampUnit(srcX - float64(xBase)) + wx := torchBicubicWeights(xFrac) + + for c := range 3 { + var sum float64 + channelOffset := c * inW * inH + for ky := range 4 { + iy := clampIndex(yBase-1+ky, 0, inH-1) + for kx := range 4 { + ix := clampIndex(xBase-1+kx, 0, inW-1) + sum += float64(src[channelOffset+iy*inW+ix]) * wy[ky] * wx[kx] + } + } + dst[c*outW*outH+oy*outW+ox] = float32(sum) + } + } + } + + return dst +} + +func cropCHWRegion(values []float32, width, height, channels, left, top, cropW, cropH int) []float32 { + out := make([]float32, channels*cropW*cropH) + channelSize := width * height + cropSize := cropW * cropH + for c := range channels { + srcBase := c * channelSize + dstBase := c * cropSize + for y := range cropH { + copy(out[dstBase+y*cropW:dstBase+(y+1)*cropW], values[srcBase+(top+y)*width+left:srcBase+(top+y)*width+left+cropW]) + } + } + return out +} + +func normalizeVisionCHW(values []float32, mean, std [3]float32) []float32 { + out := make([]float32, len(values)) + channelSize := len(values) / 3 + for c := range 3 { + base := c * channelSize + for i := range channelSize { + out[base+i] = (values[base+i] - mean[c]) / std[c] + } + } + return out +} + +func torchBicubicWeights(t float64) [4]float64 { + const a = -0.75 + return [4]float64{ + bicubicConvolution2(t+1.0, a), + bicubicConvolution1(t, a), + bicubicConvolution1(1.0-t, a), + bicubicConvolution2(2.0-t, a), + } +} + +func bicubicConvolution1(x, a float64) float64 { + return ((a+2)*x-(a+3))*x*x + 1 +} + +func bicubicConvolution2(x, a float64) float64 { + return ((a*x-5*a)*x+8*a)*x - 4*a +} + +func clampUnit(v float64) float64 { + if v < 0 { + return 0 + } + if v > 1 { + return 1 + } + return v +} + +func clampIndex(v, lo, hi int) int { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} diff --git a/model/models/nemotronh/model.go b/model/models/nemotronh/model.go index 33220fe9b..d2142db47 100644 --- a/model/models/nemotronh/model.go +++ 
b/model/models/nemotronh/model.go @@ -117,9 +117,7 @@ func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil } -func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) - +func (m *Model) forwardHiddenStates(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) { cache := m.Cache.(*HybridCache) for i, layer := range m.Layers { @@ -137,11 +135,24 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } } - hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil +} + +func (m *Model) forwardLogits(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) { + hiddenStates, err := m.forwardHiddenStates(ctx, batch, hiddenStates) + if err != nil { + return nil, err + } + return m.Output.Forward(ctx, hiddenStates), nil } -func New(c fs.Config) (model.Model, error) { +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + return m.forwardLogits(ctx, batch, hiddenStates) +} + +func newTextModel(c fs.Config) (*Model, error) { numLayers := int(c.Uint("block_count")) layers := make([]Layer, numLayers) @@ -306,6 +317,10 @@ func New(c fs.Config) (model.Model, error) { return &m, nil } +func New(c fs.Config) (model.Model, error) { + return newTextModel(c) +} + func init() { model.Register("nemotron_h", New) model.Register("nemotron_h_moe", New) diff --git a/model/models/nemotronh/model_audio.go b/model/models/nemotronh/model_audio.go new file mode 100644 index 000000000..8e1a005ef --- /dev/null +++ b/model/models/nemotronh/model_audio.go @@ -0,0 +1,511 @@ +package nemotronh + +import ( + "math" + "sync" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +type AudioOptions 
struct {
	hiddenSize        int
	numHeads          int
	headDim           int
	intermediateSize  int
	convKernelSize    int
	melBins           int
	sampleRate        int
	subsamplingKernel int
	subsamplingStride int
	scaleInput        bool
	eps               float32
}

// AudioFeatureExtractor holds the optional window and filter-bank tensors
// loaded from the model file, plus host-side copies that windowAndFilters
// fills on first use. mu guards the cached copies.
type AudioFeatureExtractor struct {
	FB     ml.Tensor `gguf:"fb"`
	Window ml.Tensor `gguf:"window"`

	mu      sync.Mutex
	fb      []float32
	window  []float32
	fbShape [2]int
}

// windowAndFilters returns the analysis window and the mel filter bank as
// float slices.
//
// With a nil receiver both are built from defaults on every call. Otherwise
// each is resolved once and cached: the tensor from the model file is used
// when present and of the expected length (parakeetWinLength for the window,
// melBins*freqBins for the filter bank); failing that, the built-in default is
// used. Later calls return the cached slices whatever arguments are passed.
func (f *AudioFeatureExtractor) windowAndFilters(melBins, freqBins, sampleRate int) ([]float32, []float32) {
	if f == nil {
		return defaultParakeetWindow(), buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.window == nil {
		if f.Window != nil {
			if values := f.Window.BackendGet(); len(values) == parakeetWinLength {
				f.window = values
			}
		}
		if f.window == nil {
			f.window = defaultParakeetWindow()
		}
	}

	if f.fb == nil {
		if f.FB != nil {
			if values := f.FB.BackendGet(); len(values) == melBins*freqBins {
				f.fb = values
				f.fbShape = [2]int{melBins, freqBins}
			}
		}
		if f.fb == nil {
			f.fb = buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
			f.fbShape = [2]int{melBins, freqBins}
		}
	}

	return f.window, f.fb
}

// AudioSubsampling holds the weights of the convolutional front end applied
// by ForwardAudio: one full convolution, two depthwise+pointwise pairs, and a
// final linear projection.
type AudioSubsampling struct {
	Conv0 *nn.Conv2D            `gguf:"conv0"`
	DW1   *AudioDepthwiseConv2D `gguf:"dw1"`
	PW1   *nn.Conv2D            `gguf:"pw1"`
	DW2   *AudioDepthwiseConv2D `gguf:"dw2"`
	PW2   *nn.Conv2D            `gguf:"pw2"`

	Linear *nn.Linear `gguf:"linear"`
}

// AudioDepthwiseConv2D holds the kernel and optional bias of a depthwise 2D
// convolution.
type AudioDepthwiseConv2D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

// AudioFeedForward holds an up/down projection pair.
type AudioFeedForward struct {
	Up   *nn.Linear `gguf:"up"`
	Down *nn.Linear `gguf:"down"`
}

// AudioSelfAttention holds the projections and the two optional query biases
// used by the relative-position self-attention in the audio encoder.
type AudioSelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_out"`
	RelativeKey *nn.Linear `gguf:"attn_rel_k"`
	BiasU       ml.Tensor  `gguf:"attn_bias_u"`
	BiasV       ml.Tensor  `gguf:"attn_bias_v"`
}

// AudioConvolutionModule holds the weights of the gated convolution module of
// an audio encoder layer.
type AudioConvolutionModule struct {
	Pointwise1 *nn.Linear        `gguf:"conv_pw1"`
	Depthwise  ml.Tensor         `gguf:"conv_dw.weight"`
	BatchNorm  *AudioBatchNorm1D `gguf:"conv_bn"`
	Pointwise2 *nn.Linear        `gguf:"conv_pw2"`
}

// AudioBatchNorm1D holds the affine parameters and running statistics of a
// batch normalization applied in inference mode.
type AudioBatchNorm1D struct {
	Weight      ml.Tensor `gguf:"weight"`
	Bias        ml.Tensor `gguf:"bias"`
	RunningMean ml.Tensor `gguf:"running_mean"`
	RunningVar  ml.Tensor `gguf:"running_var"`
}

// AudioLayer is one audio encoder block: feed-forward, self-attention,
// convolution module and a second feed-forward, each preceded by its own
// layer norm, followed by an output layer norm.
type AudioLayer struct {
	FFN1Norm *nn.LayerNorm `gguf:"ffn1_norm"`
	FFN1Up   *nn.Linear    `gguf:"ffn1_up"`
	FFN1Down *nn.Linear    `gguf:"ffn1_down"`

	AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
	Attention     *AudioSelfAttention

	ConvNorm *nn.LayerNorm `gguf:"conv_norm"`
	Conv     *AudioConvolutionModule

	FFN2Norm *nn.LayerNorm `gguf:"ffn2_norm"`
	FFN2Up   *nn.Linear    `gguf:"ffn2_up"`
	FFN2Down *nn.Linear    `gguf:"ffn2_down"`

	OutputNorm *nn.LayerNorm `gguf:"out_norm"`
}

// AudioModel is the audio encoder: feature extractor tensors, the subsampling
// front end and the stack of encoder layers, plus its options.
type AudioModel struct {
	FeatureExtractor *AudioFeatureExtractor `gguf:"feature_extractor"`
	Subsampling      *AudioSubsampling      `gguf:"subsampling"`
	Layers           []AudioLayer           `gguf:"blk"`

	*AudioOptions
}

// AudioProjector maps audio encoder outputs through a norm and two linear
// layers.
type AudioProjector struct {
	Norm    *nn.RMSNorm `gguf:"norm"`
	Linear1 *nn.Linear  `gguf:"1"`
	Linear2 *nn.Linear  `gguf:"2"`
}

// Forward applies RMS norm, Linear1, ReLU, an element-wise square, and
// Linear2. Both linear outputs are passed through audioF32.
func (p *AudioProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	x = p.Norm.Forward(ctx, x, eps)
	x = audioF32(ctx, p.Linear1.Forward(ctx, x))
	x = x.RELU(ctx)
	x = x.Mul(ctx, x)
	return audioF32(ctx, p.Linear2.Forward(ctx, x))
}

// ForwardAudio runs the audio encoder on melFeatures and, when projector is
// non-nil, projects the result.
//
// The input is reshaped to four dimensions and passed through three strided
// convolution stages. After each stage the count of valid frames is updated
// with convOutputLength and frames past it are zeroed by applyAudioTimeMask.
// The result is flattened, projected, optionally scaled by sqrt(hiddenSize),
// truncated to the valid frames, and run through every encoder layer.
// validFrames <= 0 disables masking and truncation.
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, validFrames int, projector *AudioProjector) ml.Tensor {
	x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
	validLen := validFrames

	x = forwardAudioConv2D(ctx, m.Subsampling.Conv0, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
	x = x.RELU(ctx)
	validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
	x = applyAudioTimeMask(ctx, x, validLen)

	x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW1, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
	x = forwardAudioConv2D(ctx, m.Subsampling.PW1, x, 1, 1, 0, 0, 1, 1)
	x = x.RELU(ctx)
	validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
	x = applyAudioTimeMask(ctx, x, validLen)

	x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW2, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
	x = forwardAudioConv2D(ctx, m.Subsampling.PW2, x, 1, 1, 0, 0, 1, 1)
	x = x.RELU(ctx)
	validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
	x = applyAudioTimeMask(ctx, x, validLen)

	x = flattenAudioSubsamplingOutput(ctx, x)
	x = m.Subsampling.Linear.Forward(ctx, x)

	if m.scaleInput {
		x = x.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
	}
	// Drop padded frames so the encoder layers only see valid positions.
	if validLen > 0 && validLen < x.Dim(1) {
		x = x.Slice(ctx, 1, 0, validLen, 1).Contiguous(ctx)
	}

	for i := range m.Layers {
		x = m.Layers[i].Forward(ctx, x, validLen, m.AudioOptions)
	}

	if projector != nil {
		x = projector.Forward(ctx, x, m.eps)
	}
	return x
}

// flattenAudioSubsamplingOutput merges dimensions 0 and 2 of x into one
// feature dimension of size Dim(2)*Dim(0), keeping Dim(1) as the second
// dimension.
func flattenAudioSubsamplingOutput(ctx ml.Context, x ml.Tensor) ml.Tensor {
	fOut, tOut, cOut := x.Dim(0), x.Dim(1), x.Dim(2)
	// PyTorch flattens the subsampling output after [B, C, T, F] ->
	// [B, T, C, F], so F must remain the fastest dimension inside each
	// channel block before the linear projection.
	x = x.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	return x.Reshape(ctx, cOut*fOut, tOut)
}

// Forward runs one encoder block on x. Each of the four sub-modules is applied
// to a layer-normed copy of x and added back as a residual; both feed-forward
// outputs are halved before the addition. The sum is layer-normed at the end.
func (l *AudioLayer) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
	residual := x
	x = audioFeedForward(ctx, l.FFN1Up, l.FFN1Down, l.FFN1Norm.Forward(ctx, x, opts.eps)).Scale(ctx, 0.5)
	x = residual.Add(ctx, x)

	residual = x
	x = l.Attention.Forward(ctx, l.AttentionNorm.Forward(ctx, x, opts.eps), validLen, opts)
	x = residual.Add(ctx, x)

	residual = x
	x = l.Conv.Forward(ctx, l.ConvNorm.Forward(ctx, x, opts.eps), opts)
	x = residual.Add(ctx, x)

	residual = x
	x = audioFeedForward(ctx, l.FFN2Up, l.FFN2Down, l.FFN2Norm.Forward(ctx, x, opts.eps)).Scale(ctx, 0.5)
	x = residual.Add(ctx, x)

	return l.OutputNorm.Forward(ctx, x, opts.eps)
}

// audioFeedForward applies up, SiLU, then down, passing both linear outputs
// through audioF32.
func audioFeedForward(ctx ml.Context, up, down *nn.Linear, x ml.Tensor) ml.Tensor {
	x = audioF32(ctx, up.Forward(ctx, x))
	x = x.SILU(ctx)
	return audioF32(ctx, down.Forward(ctx, x))
}

// Forward computes self-attention with relative position terms over x.
//
// Two copies of the query are formed by adding BiasU and BiasV (when present).
// The BiasU copy is multiplied with the keys to give content logits. The BiasV
// copy is multiplied with the projected position embeddings and realigned by
// relativeShiftParakeet to give position logits. Their sum is scaled by
// headDim^-0.5, masked when validLen is in (0, seqLen), and passed through
// softmax. The weighted values are reshaped to hiddenSize and sent through the
// output projection.
func (a *AudioSelfAttention) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
	seqLen := x.Dim(1)
	headDim := opts.headDim
	numHeads := opts.numHeads

	q := audioF32(ctx, a.Query.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
	k := audioF32(ctx, a.Key.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
	v := audioF32(ctx, a.Value.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)

	qU := q
	if a.BiasU != nil {
		qU = qU.Add(ctx, audioF32(ctx, a.BiasU).Reshape(ctx, headDim, numHeads, 1))
	}
	qV := q
	if a.BiasV != nil {
		qV = qV.Add(ctx, audioF32(ctx, a.BiasV).Reshape(ctx, headDim, numHeads, 1))
	}

	qP := qU.Permute(ctx, 0, 2, 1, 3)
	kP := k.Permute(ctx, 0, 2, 1, 3)
	logits := kP.MulmatFullPrec(ctx, qP)

	// One embedding per relative offset, 2*seqLen-1 in total.
	positionEmbeddings := parakeetPositionEmbeddings(ctx, seqLen, opts.hiddenSize)
	relKey := audioF32(ctx, a.RelativeKey.Forward(ctx, positionEmbeddings)).Reshape(ctx, headDim, numHeads, 2*seqLen-1)
	pP := relKey.Permute(ctx, 0, 2, 1, 3)
	qVP := qV.Permute(ctx, 0, 2, 1, 3)
	relLogits := pP.MulmatFullPrec(ctx, qVP)
	relLogits = relativeShiftParakeet(ctx, relLogits, seqLen, numHeads)

	logits = logits.Add(ctx, relLogits)
	logits = logits.Scale(ctx, math.Pow(float64(headDim), -0.5))
	if validLen > 0 && validLen < seqLen {
		logits = logits.Add(ctx, audioAttentionMask(ctx, seqLen, validLen))
	}
	logits = logits.Softmax(ctx)

	vP := v.Permute(ctx, 0, 2, 1, 3)
	vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	out := vPT.Mulmat(ctx, logits)
	out = out.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	out = out.Reshape(ctx, opts.hiddenSize, seqLen)
	return audioF32(ctx, a.Output.Forward(ctx, out))
}

// Forward runs the gated convolution module: Pointwise1, a gate that splits
// dimension 0 in half and multiplies the first half by the sigmoid of the
// second, a depthwise convolution along dimension 1, batch norm, SiLU and
// Pointwise2.
func (c *AudioConvolutionModule) Forward(ctx ml.Context, x ml.Tensor, opts *AudioOptions) ml.Tensor {
	x = audioF32(ctx, c.Pointwise1.Forward(ctx, x))
	hidden := x.Dim(0) / 2
	value := x.Slice(ctx, 0, 0, hidden, 1).Contiguous(ctx)
	gate := x.Slice(ctx, 0, hidden, 2*hidden, 1).Contiguous(ctx).Sigmoid(ctx)
	x = value.Mul(ctx, gate)

	x = audioDepthwiseConv1DSame(ctx, x, c.Depthwise, audioConvPadding(opts.convKernelSize))
	x = c.BatchNorm.Forward(ctx, x, opts.eps)
	x = x.SILU(ctx)
	return audioF32(ctx, c.Pointwise2.Forward(ctx, x))
}

// audioF32 returns x unchanged when it is already F32 and a cast of x to F32
// otherwise.
func audioF32(ctx ml.Context, x ml.Tensor) ml.Tensor {
	if x.DType() == ml.DTypeF32 {
		return x
	}

	// Metal binary kernels used by the audio graph require F32 operands here.
	// This likely slows audio and should be revisited once the precision vs.
	// speed tradeoff is validated against BF16-native elementwise paths.
+ return x.Cast(ctx, ml.DTypeF32) +} + +func (b *AudioBatchNorm1D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor { + if b == nil || b.RunningMean == nil || b.RunningVar == nil { + return x + } + + hidden := x.Dim(0) + epsValues := make([]float32, hidden) + for i := range epsValues { + epsValues[i] = eps + } + + variance := b.RunningVar.Add(ctx, ctx.Input().FromFloats(epsValues, hidden)) + x = x.Sub(ctx, b.RunningMean) + x = x.Div(ctx, variance.Sqrt(ctx)) + if b.Weight != nil { + x = x.Mul(ctx, b.Weight) + } + if b.Bias != nil { + x = x.Add(ctx, b.Bias) + } + return x +} + +func forwardAudioConv2D(ctx ml.Context, conv *nn.Conv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + weight := conv.Weight.Contiguous(ctx) + x = weight.Conv2D(ctx, x, s0, s1, p0, p1, d0, d1) + if conv.Bias != nil { + x = x.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1)) + } + return x +} + +func forwardAudioDepthwiseConv2D(ctx ml.Context, conv *AudioDepthwiseConv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + x = audioDepthwiseConv2D(ctx, x, conv.Weight, s0, s1, p0, p1, d0, d1) + if conv.Bias != nil { + x = x.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1)) + } + return x +} + +func applyAudioTimeMask(ctx ml.Context, x ml.Tensor, validLen int) ml.Tensor { + if validLen <= 0 || validLen >= x.Dim(1) { + return x + } + mask := make([]float32, x.Dim(1)) + for i := range validLen { + mask[i] = 1 + } + return x.Mul(ctx, ctx.Input().FromFloats(mask, 1, x.Dim(1), 1, 1)) +} + +func audioDepthwiseConv1DSame(ctx ml.Context, x, kernel ml.Tensor, padding int) ml.Tensor { + kernelSize := kernel.Dim(0) + seqLen := x.Dim(1) + + kernelT := kernel.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + var out ml.Tensor + for k := range kernelSize { + offset := k - padding + shifted := x + switch { + case offset > 0: + shifted = x.Slice(ctx, 1, offset, seqLen, 1).Contiguous(ctx) + shifted = shifted.PadExt(ctx, 0, 0, 0, offset, 0, 0, 0, 0) + case offset < 0: + shift := -offset + shifted = x.Slice(ctx, 
1, 0, seqLen-shift, 1).Contiguous(ctx) + shifted = shifted.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0) + } + + wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) + term := shifted.Mul(ctx, wk) + if out == nil { + out = term + } else { + out = out.Add(ctx, term) + } + } + return out +} + +func audioDepthwiseConv2D(ctx ml.Context, x, kernel ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + if d0 != 1 || d1 != 1 { + panic("audio depthwise conv2d only supports dilation 1") + } + + kernel = kernel.Contiguous(ctx) + kernelW, kernelH := kernel.Dim(0), kernel.Dim(1) + outW := convOutputLength(x.Dim(0), kernelW, s0, p0) + outH := convOutputLength(x.Dim(1), kernelH, s1, p1) + padded := x.PadExt(ctx, p0, p0, p1, p1, 0, 0, 0, 0) + + var out ml.Tensor + for ky := range kernelH { + for kx := range kernelW { + patch := padded.Slice(ctx, 0, kx, kx+s0*(outW-1)+1, s0).Contiguous(ctx) + patch = patch.Slice(ctx, 1, ky, ky+s1*(outH-1)+1, s1).Contiguous(ctx) + + wk := kernel.Slice(ctx, 0, kx, kx+1, 1).Slice(ctx, 1, ky, ky+1, 1).Contiguous(ctx) + if wk.Dim(2) == 1 { + wk = wk.Permute(ctx, 0, 1, 3, 2).Contiguous(ctx) + } else { + wk = wk.Reshape(ctx, 1, 1, wk.Dim(2), wk.Dim(3)) + } + + term := patch.Mul(ctx, wk) + if out == nil { + out = term + } else { + out = out.Add(ctx, term) + } + } + } + return out +} + +func convOutputLength(inputLength, kernel, stride, padding int) int { + if inputLength <= 0 { + return 0 + } + return (inputLength+2*padding-kernel)/stride + 1 +} + +func audioConvPadding(kernel int) int { + return (kernel - 1) / 2 +} + +func parakeetPositionEmbeddings(ctx ml.Context, seqLen, hiddenSize int) ml.Tensor { + half := hiddenSize / 2 + values := make([]float32, hiddenSize*(2*seqLen-1)) + for posIdx, pos := 0, seqLen-1; posIdx < 2*seqLen-1; posIdx, pos = posIdx+1, pos-1 { + for i := range half { + invFreq := math.Pow(10000, -float64(2*i)/float64(hiddenSize)) + angle := float64(pos) * invFreq + values[posIdx*hiddenSize+2*i] = float32(math.Sin(angle)) + 
			values[posIdx*hiddenSize+2*i+1] = float32(math.Cos(angle))
		}
	}
	return ctx.Input().FromFloats(values, hiddenSize, 2*seqLen-1)
}

// relativeShiftParakeet realigns position logits indexed by relative offset
// (2*seqLen-1 entries per query) so that they are indexed by key position
// (seqLen entries per query). It pads one element at the front of dimension 0,
// reinterprets the buffer with the first two dimensions swapped in size, drops
// the first row, reinterprets back, and keeps the first seqLen entries.
func relativeShiftParakeet(ctx ml.Context, x ml.Tensor, seqLen, numHeads int) ml.Tensor {
	positionLen := 2*seqLen - 1
	x = x.PadExt(ctx, 1, 0, 0, 0, 0, 0, 0, 0)
	x = x.Reshape(ctx, seqLen, positionLen+1, numHeads)
	x = x.Slice(ctx, 1, 1, positionLen+1, 1).Contiguous(ctx)
	x = x.Reshape(ctx, positionLen, seqLen, numHeads)
	return x.Slice(ctx, 0, 0, seqLen, 1).Contiguous(ctx)
}

// audioAttentionMask builds an additive seqLen x seqLen mask holding -1e9
// wherever the query or the key index is at or past validLen, and 0 elsewhere.
func audioAttentionMask(ctx ml.Context, seqLen, validLen int) ml.Tensor {
	values := make([]float32, seqLen*seqLen)
	for q := range seqLen {
		for k := range seqLen {
			if q >= validLen || k >= validLen {
				values[q*seqLen+k] = -1e9
			}
		}
	}
	return ctx.Input().FromFloats(values, seqLen, seqLen, 1)
}

// newAudioModel returns an AudioModel with "audio.block_count" layers, or nil
// when that count is zero or absent.
func newAudioModel(c fs.Config) *AudioModel {
	numLayers := int(c.Uint("audio.block_count", 0))
	if numLayers == 0 {
		return nil
	}
	return &AudioModel{
		Layers:       make([]AudioLayer, numLayers),
		AudioOptions: newAudioOptions(c),
	}
}

// newAudioProjector returns an empty AudioProjector, or nil when the model has
// no audio layers.
func newAudioProjector(c fs.Config) *AudioProjector {
	if c.Uint("audio.block_count", 0) == 0 {
		return nil
	}
	return &AudioProjector{}
}

// newAudioOptions reads the audio encoder options from the "audio.*" config
// keys, using the defaults given here for missing keys. headDim is derived as
// hiddenSize divided by the head count.
func newAudioOptions(c fs.Config) *AudioOptions {
	hiddenSize := int(c.Uint("audio.embedding_length", 1024))
	numHeads := int(c.Uint("audio.attention.head_count", 8))
	headDim := hiddenSize / max(1, numHeads)
	return &AudioOptions{
		hiddenSize:        hiddenSize,
		numHeads:          numHeads,
		headDim:           headDim,
		intermediateSize:  int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
		convKernelSize:    int(c.Uint("audio.conv_kernel_size", 9)),
		melBins:           int(c.Uint("audio.num_mel_bins", 128)),
		sampleRate:        int(c.Uint("audio.sample_rate", 16000)),
		subsamplingKernel: int(c.Uint("audio.subsampling_conv_kernel_size", 3)),
		subsamplingStride: int(c.Uint("audio.subsampling_conv_stride", 2)),
		scaleInput:        c.Bool("audio.scale_input", false),
		eps:               c.Float("audio.attention.layer_norm_epsilon", 1e-5),
	}
}

// defaultAudioOptions returns the same values newAudioOptions uses when every
// config key is missing.
func defaultAudioOptions() *AudioOptions {
	return &AudioOptions{
		hiddenSize:        1024,
		numHeads:          8,
		headDim:           128,
		intermediateSize:  4096,
		convKernelSize:    9,
		melBins:           128,
		sampleRate:        16000,
		subsamplingKernel: 3,
		subsamplingStride: 2,
		eps:               1e-5,
	}
}

// diff --git a/model/models/nemotronh/model_omni.go b/model/models/nemotronh/model_omni.go
// new file mode 100644
// index 000000000..ad42be008
// --- /dev/null
// +++ b/model/models/nemotronh/model_omni.go
// @@ -0,0 +1,239 @@
package nemotronh

import (
	"bytes"
	"errors"
	"image"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// OmniModel combines the text model with optional vision and audio encoders
// and their projectors, and records the placeholder token ids used for image
// and audio spans.
type OmniModel struct {
	*Model
	*VisionModel         `gguf:"v"`
	*AudioModel          `gguf:"a"`
	*MultiModalProjector `gguf:"mm"`
	*AudioProjector      `gguf:"mm.a"`

	ImageProcessor

	imageTokenID    int32
	imageStartToken int32
	imageEndToken   int32
	audioTokenID    int32
}

var _ model.MultimodalProcessor = (*OmniModel)(nil)

// NewOmni builds an OmniModel from c. The token ids default to 18 (image),
// 19 (image start), 20 (image end) and 27 (audio) when not configured.
func NewOmni(c fs.Config) (model.Model, error) {
	textModel, err := newTextModel(c)
	if err != nil {
		return nil, err
	}

	imageTokenID := int32(c.Uint("vision.image_token_id", 18))
	imageStartToken := int32(c.Uint("vision.image_start_token_id", 19))
	imageEndToken := int32(c.Uint("vision.image_end_token_id", 20))
	audioTokenID := int32(c.Uint("audio.sound_token_id", 27))

	return &OmniModel{
		Model:               textModel,
		VisionModel:         newVisionModel(c),
		AudioModel:          newAudioModel(c),
		MultiModalProjector: newMultiModalProjector(c),
		AudioProjector:      newAudioProjector(c),
		ImageProcessor:      newImageProcessor(c),
		imageTokenID:        imageTokenID,
		imageStartToken:     imageStartToken,
		imageEndToken:       imageEndToken,
		audioTokenID:        audioTokenID,
	}, nil
}

// EncodeMultimodal encodes one attachment. Audio data is handed to
// encodeAudioMultimodal; anything else is decoded as an image, split into
// tiles, and each tile is encoded and projected into its own entry.
func (m *OmniModel) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
	if isAudioData(multimodalData) {
		return m.encodeAudioMultimodal(ctx, multimodalData)
	}

	if m.VisionModel == nil || m.MultiModalProjector == nil || len(m.VisionModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	tiles, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, err
	}

	mm := make([]input.Multimodal, 0, len(tiles))
	for _, tile := range tiles {
		patches := visionPatchGrid{
			Width:  tile.size.X / m.ImageProcessor.patchSize,
			Height: tile.size.Y / m.ImageProcessor.patchSize,
		}
		if patches.Width == 0 || patches.Height == 0 {
			return nil, errors.New("nemotron_h_omni: invalid resized image dimensions")
		}

		patchInput := packVisionPatchesCHW(tile.data, tile.size.X, tile.size.Y, m.ImageProcessor.numChannels, m.ImageProcessor.patchSize)
		visionOutputs := m.VisionModel.ForwardPacked(ctx, patchInput, patches)
		projected := m.MultiModalProjector.Forward(ctx, visionOutputs, patches)
		mm = append(mm, input.Multimodal{Tensor: projected})
	}

	return mm, nil
}

// audioTag marks an input.Multimodal entry as audio so PostTokenize can tell
// it apart from image entries.
type audioTag struct{}

// encodeAudioMultimodal decodes WAV data at the model's sample rate, computes
// the mel spectrogram, and runs the audio encoder and projector. The single
// returned entry carries audioTag in its Data field.
func (m *OmniModel) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
	// NOTE(review): a missing audio encoder is reported with the vision
	// sentinel error; confirm whether callers rely on that.
	if m.AudioModel == nil || m.AudioProjector == nil || len(m.AudioModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	samples, err := decodeWAV(data, m.AudioModel.sampleRate)
	if err != nil {
		return nil, err
	}

	melData, frames, validFrames, err := computeParakeetMelSpectrogram(samples, m.AudioModel.FeatureExtractor, m.AudioModel.AudioOptions)
	if err != nil {
		return nil, err
	}

	melTensor := ctx.Input().FromFloats(melData, m.AudioModel.melBins, frames)
	audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, validFrames, m.AudioProjector)
	return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
}

// PostLoad does nothing for this model.
func (m *OmniModel) PostLoad() error {
	return nil
}

// PostTokenize expands every multimodal input into a run of placeholder
// tokens, one per embedding column (Dim(1)) of its tensors. Inputs without
// multimodal data are passed through.
//
// Audio (first entry tagged with audioTag): for each entry, the first
// placeholder carries the entry and has SameBatch set to the remaining count;
// only the first entry's placeholder carries the multimodal hash.
//
// Images: an optional start token with SameBatch covering all placeholders and
// the optional end token, then for each entry a placeholder carrying the entry
// and the hash followed by plain placeholders, then the optional end token.
//
// It returns an error when an input or one of its entries has no tokens.
func (m *OmniModel) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
	var result []*input.Input

	imageToken := m.imageTokenID
	if imageToken == 0 {
		imageToken = 18
	}

	for _, inp := range inputs {
		if len(inp.Multimodal) == 0 {
			result = append(result, inp)
			continue
		}

		totalTokens := 0
		for _, mm := range inp.Multimodal {
			if mm.Tensor == nil {
				continue
			}
			totalTokens += mm.Tensor.Dim(1)
		}
		if totalTokens <= 0 {
			return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
		}

		if _, ok := inp.Multimodal[0].Data.(audioTag); ok {
			audioToken := m.audioTokenID
			if audioToken == 0 {
				audioToken = 27
			}

			for i, mm := range inp.Multimodal {
				tokenCount := 0
				if mm.Tensor != nil {
					tokenCount = mm.Tensor.Dim(1)
				}
				if tokenCount <= 0 {
					return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
				}

				first := &input.Input{Token: audioToken, SameBatch: tokenCount - 1}
				if i == 0 {
					first.MultimodalHash = inp.MultimodalHash
				}
				first.Multimodal = []input.Multimodal{mm}
				result = append(result, first)
				if tokenCount > 1 {
					// slices.Repeat copies the pointer, so these placeholders
					// all refer to the same Input value.
					result = append(result, slices.Repeat([]*input.Input{{Token: audioToken}}, tokenCount-1)...)
				}
			}
			continue
		}

		if m.imageStartToken > 0 {
			result = append(result, &input.Input{
				Token:     m.imageStartToken,
				SameBatch: totalTokens + btoi(m.imageEndToken > 0),
			})
		}

		for _, mm := range inp.Multimodal {
			tokenCount := 0
			if mm.Tensor != nil {
				tokenCount = mm.Tensor.Dim(1)
			}
			if tokenCount <= 0 {
				return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
			}

			result = append(result, &input.Input{
				Token:          imageToken,
				Multimodal:     []input.Multimodal{mm},
				MultimodalHash: inp.MultimodalHash,
			})
			if tokenCount > 1 {
				// As above, the repeated placeholders share one Input value.
				result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, tokenCount-1)...)
			}
		}

		if m.imageEndToken > 0 {
			result = append(result, &input.Input{Token: m.imageEndToken})
		}
	}

	return result, nil
}

// btoi returns 1 for true and 0 for false.
func btoi(v bool) int {
	if v {
		return 1
	}
	return 0
}

// Forward embeds the batch tokens, copies each multimodal tensor over the
// embeddings starting at the entry's batch index, and runs the text model on
// the result. The embeddings are duplicated first when there is anything to
// overwrite.
func (m *OmniModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	if len(batch.Multimodal) > 0 {
		hiddenStates = hiddenStates.Duplicate(ctx)
	}

	for _, mm := range batch.Multimodal {
		offset := mm.Index
		for _, multimodal := range mm.Multimodal {
			if multimodal.Tensor == nil {
				continue
			}

			tensor := multimodal.Tensor
			ctx.Forward(tensor.Copy(ctx, hiddenStates.View(ctx, offset*hiddenStates.Stride(1), tensor.Dim(0)*tensor.Dim(1))))
			// Consecutive tensors of one entry occupy consecutive columns.
			offset += tensor.Dim(1)
		}
	}

	return m.forwardLogits(ctx, batch, hiddenStates)
}

func init() {
	model.Register("nemotron_h_omni", NewOmni)
}

// diff --git a/model/models/nemotronh/model_omni_test.go b/model/models/nemotronh/model_omni_test.go
// new file mode 100644
// index 000000000..f211c4630
// --- /dev/null
// +++ b/model/models/nemotronh/model_omni_test.go
// @@ -0,0 +1,606 @@
package nemotronh

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
	"image"
	"image/color"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"testing"

	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	backendggml "github.com/ollama/ollama/ml/backend/ggml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model/input"
)

// fakeTensor is a test tensor that only answers Dim from a fixed shape.
type fakeTensor struct {
	*backendggml.Tensor
	dims []int
}

func (t *fakeTensor) Dim(i int) int {
	return t.dims[i]
}

// setupTestContext writes a minimal GGUF file to a temp dir, opens a backend
// on it, and returns an input context that is closed when the test ends.
func setupTestContext(t *testing.T) ml.Context {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "*.gguf")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	if err := fsggml.WriteGGUF(f, fsggml.KV{"general.architecture": "test"}, nil); err != nil {
		t.Fatal(err)
	}

	b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
	if err != nil {
		t.Fatal(err)
	}

	ctx := b.NewContext().Input()
	t.Cleanup(func() {
		ctx.Close()
		b.Close()
	})

	return ctx
}

// TestPostTokenizeImageSpans checks the token layout for an input with two
// 256-column image chunks: a start token whose SameBatch covers both chunks
// and the end token, one payload-carrying placeholder at the head of each
// chunk, and the end token, with the surrounding text tokens untouched.
func TestPostTokenizeImageSpans(t *testing.T) {
	m := &OmniModel{
		imageTokenID:    18,
		imageStartToken: 19,
		imageEndToken:   20,
	}

	makeChunk := func() input.Multimodal {
		return input.Multimodal{Tensor: &fakeTensor{dims: []int{2688, 256, 1, 1}}}
	}

	in := []*input.Input{
		{Token: 7},
		{
			Multimodal:     []input.Multimodal{makeChunk(), makeChunk()},
			MultimodalHash: 99,
		},
		{Token: 8},
	}

	out, err := m.PostTokenize(in)
	if err != nil {
		t.Fatalf("PostTokenize() error = %v", err)
	}

	// 1 text + 1 start + 2*256 placeholders + 1 end + 1 text.
	if len(out) != 516 {
		t.Fatalf("len(out) = %d, want 516", len(out))
	}

	if out[0].Token != 7 {
		t.Fatalf("out[0].Token = %d, want 7", out[0].Token)
	}
	if out[1].Token != 19 {
		t.Fatalf("out[1].Token = %d, want 19", out[1].Token)
	}
	if out[1].SameBatch != 513 {
		t.Fatalf("out[1].SameBatch = %d, want 513", out[1].SameBatch)
	}
	if out[2].Token != 18 || len(out[2].Multimodal) != 1 || out[2].MultimodalHash != 99 || out[2].SameBatch != 0 {
		t.Fatalf("unexpected first image token: %+v", *out[2])
	}
	if out[258].Token != 18 || len(out[258].Multimodal) != 1 || out[258].MultimodalHash != 99 || out[258].SameBatch != 0 {
		t.Fatalf("unexpected second image token: %+v", *out[258])
	}
	if out[514].Token != 20 {
		t.Fatalf("out[514].Token = %d, want 20", out[514].Token)
	}
	if out[515].Token != 8 {
		t.Fatalf("out[515].Token = %d, want 8", out[515].Token)
	}
}

// TestProjectorPixelShuffleMatchesReferenceV2Order checks the element order
// pixelShuffleVisionOutputs produces for a 4x2 grid of 2-channel features with
// a shuffle factor of 2. Input values encode their position as 100*y+10*x+c.
func TestProjectorPixelShuffleMatchesReferenceV2Order(t *testing.T) {
	ctx := setupTestContext(t)

	hidden := 2
	width := 4
	height := 2
	values := make([]float32, 0, hidden*width*height)
	for y := range height {
		for x := range width {
			for c := range hidden {
				values = append(values, float32(100*y+10*x+c))
			}
		}
	}

	got := pixelShuffleVisionOutputs(ctx, ctx.FromFloats(values, hidden, width*height), visionPatchGrid{
		Width:  width,
		Height: height,
	}, 2)
	ctx.Forward(got).Compute(got)

	want := []float32{
		0, 1, 10, 11, 100, 101, 110, 111,
		20, 21, 30, 31, 120, 121, 130, 131,
	}
	if got.Shape()[0] != 8 || got.Shape()[1] != 2 {
		t.Fatalf("shape = %v, want [8 2 1]", got.Shape())
	}
	gotValues := got.BackendGet()
	if len(gotValues) != len(want) {
		t.Fatalf("len(got) = %d, want %d", len(gotValues), len(want))
	}
	for i := range want {
		if gotValues[i] != want[i] {
			t.Fatalf("got[%d] = %v, want %v", i, gotValues[i], want[i])
		}
	}
}

// TestPostTokenizeAudioSpans checks the token layout for a 13-column audio
// entry: 13 audio placeholders, of which only the first carries the payload,
// the hash and SameBatch = 12.
func TestPostTokenizeAudioSpans(t *testing.T) {
	m := &OmniModel{
		audioTokenID: 27,
	}

	in := []*input.Input{
		{Token: 7},
		{
			Multimodal: []input.Multimodal{{
				Tensor: &fakeTensor{dims: []int{2688, 13, 1, 1}},
				Data:   audioTag{},
			}},
			MultimodalHash: 99,
		},
		{Token: 8},
	}

	out, err := m.PostTokenize(in)
	if err != nil {
		t.Fatalf("PostTokenize() error = %v", err)
	}

	if len(out) != 15 {
		t.Fatalf("len(out) = %d, want 15", len(out))
	}
	if out[0].Token != 7 || out[14].Token != 8 {
		t.Fatalf("unexpected surrounding tokens: first=%d last=%d", out[0].Token, out[14].Token)
	}
	for i := 1; i <= 13; i++ {
		if out[i].Token != 27 {
			t.Fatalf("out[%d].Token = %d, want 27", i, out[i].Token)
		}
	}
	if len(out[1].Multimodal) != 1 || out[1].MultimodalHash != 99 {
		t.Fatalf("first audio token did not carry multimodal payload: %+v", *out[1])
	}
	if out[1].SameBatch != 12 {
		t.Fatalf("first audio token SameBatch = %d, want 12", out[1].SameBatch)
	}
	if len(out[2].Multimodal) != 0 {
		t.Fatalf("only the first audio token should carry multimodal payload: %+v", *out[2])
	}
}

// TestParakeetAudioPreprocessShapes checks sample, frame and valid-frame
// counts for one second of a 440 Hz tone at 16 kHz, and that the frame past
// the valid range is all zeros.
func TestParakeetAudioPreprocessShapes(t *testing.T) {
	data := sineWAV(t, 16000, 440, 1.0)
	samples, err := decodeWAV(data, 16000)
	if err != nil {
		t.Fatal(err)
	}
	if got, want := len(samples), 16000; got != want {
		t.Fatalf("sample count = %d, want %d", got, want)
	}

	mel, frames, validFrames, err :=
computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions()) + if err != nil { + t.Fatal(err) + } + if frames != 101 { + t.Fatalf("frames = %d, want 101", frames) + } + if validFrames != 100 { + t.Fatalf("validFrames = %d, want 100", validFrames) + } + if len(mel) != 101*128 { + t.Fatalf("len(mel) = %d, want %d", len(mel), 101*128) + } + lastFrame := mel[100*128 : 101*128] + if !slices.Equal(lastFrame, make([]float32, 128)) { + t.Fatal("expected masked final frame to be zero") + } +} + +func TestParakeetAudioPreprocessMatchesIntegrationWAVReference(t *testing.T) { + data := integrationAudioWAV(t) + samples, err := decodeWAV(data, 16000) + if err != nil { + t.Fatal(err) + } + if got, want := len(samples), 42083; got != want { + t.Fatalf("sample count = %d, want %d", got, want) + } + + mel, frames, validFrames, err := computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions()) + if err != nil { + t.Fatal(err) + } + if frames != 264 { + t.Fatalf("frames = %d, want 264", frames) + } + if validFrames != 263 { + t.Fatalf("validFrames = %d, want 263", validFrames) + } + if len(mel) != 264*128 { + t.Fatalf("len(mel) = %d, want %d", len(mel), 264*128) + } + lastFrame := mel[263*128 : 264*128] + if !slices.Equal(lastFrame, make([]float32, 128)) { + t.Fatal("expected masked final frame to be zero") + } + + // Reference values come from the ParakeetExtractor path used by vLLM: + // pre-emphasis, torch.stft(center=True, pad_mode="constant"), Slaney mel + // filters, log guard 2^-24, and per-mel normalization over valid frames. 
// integrationAudioWAV returns the WAV bytes used by the integration tests. It
// reads the integration test source file, takes the text between the opening
// backtick of the audioEncodingPrompt constant and the next backtick, and
// base64-decodes it. Any failure aborts the calling test.
func integrationAudioWAV(t *testing.T) []byte {
	t.Helper()

	src, err := os.ReadFile(filepath.Join("..", "..", "..", "integration", "audio_test_data_test.go"))
	if err != nil {
		t.Fatal(err)
	}

	_, rest, found := strings.Cut(string(src), "const audioEncodingPrompt = `")
	if !found {
		t.Fatal("audioEncodingPrompt marker not found")
	}
	encoded, _, found := strings.Cut(rest, "`")
	if !found {
		t.Fatal("audioEncodingPrompt terminator not found")
	}

	wav, err := base64.StdEncoding.DecodeString(strings.TrimSpace(encoded))
	if err != nil {
		t.Fatal(err)
	}
	return wav
}
1 + } + kernelValues := make([]float32, 3*3*channels) + for i := range kernelValues { + kernelValues[i] = float32(i)/7 - 1 + } + + x := ctx.FromFloats(xValues, freq, frames, channels, 1) + kernel := ctx.FromFloats(kernelValues, 3, 3, 1, channels) + bias := ctx.FromFloats([]float32{0.25, -0.5}, channels) + got := audioDepthwiseConv2D(ctx, x, kernel, 2, 2, 1, 1, 1, 1).Add(ctx, bias.Reshape(ctx, 1, 1, -1)) + ctx.Forward(got).Compute(got) + + want := []float32{ + 0.86428565, 1.3357141, + 1.2785715, 1.3642857, + -0.5928571, -1.7499999, + 5.4000001, 8.8142853, + 10.514286, 16.042856, + 6.6857138, 9.8428574, + } + assertCloseSlice(t, got.BackendGet(), want, 1e-5) +} + +func TestFlattenAudioSubsamplingOutputMatchesReference(t *testing.T) { + ctx := setupTestContext(t) + + const ( + freq = 2 + frames = 3 + channels = 2 + ) + values := make([]float32, freq*frames*channels) + for c := range channels { + for t := range frames { + for f := range freq { + values[f+freq*(t+frames*c)] = float32(100*c + 10*t + f) + } + } + } + + got := flattenAudioSubsamplingOutput(ctx, ctx.FromFloats(values, freq, frames, channels, 1)) + ctx.Forward(got).Compute(got) + + want := []float32{ + 0, 1, 100, 101, + 10, 11, 110, 111, + 20, 21, 120, 121, + } + assertCloseSlice(t, got.BackendGet(), want, 0) +} + +func TestAudioDepthwiseConv1DMatchesReference(t *testing.T) { + ctx := setupTestContext(t) + + xValues := make([]float32, 2*5) + for i := range xValues { + xValues[i] = float32(i)/5 - 0.7 + } + kernelValues := make([]float32, 3*2) + for i := range kernelValues { + kernelValues[i] = float32(i)/3 - 0.5 + } + + x := ctx.FromFloats(xValues, 2, 5) + kernel := ctx.FromFloats(kernelValues, 3, 2) + got := audioDepthwiseConv1DSame(ctx, x, kernel, 1) + ctx.Forward(got).Compute(got) + + want := []float32{ + 0.066666655, -0.5333333, + 0.41666666, 0.016666688, + 0.21666668, 1.0166667, + 0.01666667, 2.0166664, + -0.40000004, 1.2666667, + } + assertCloseSlice(t, got.BackendGet(), want, 1e-5) +} + +func 
// assertCloseSlice fails the test immediately unless got and want have the
// same length and every element of got is within tolerance of the matching
// element of want. A NaN in got that is not matched by a NaN in want is
// reported as a mismatch instead of slipping through the comparison.
func assertCloseSlice(t *testing.T, got, want []float32, tolerance float64) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("len(got) = %d, want %d", len(got), len(want))
	}
	if i := firstFloatMismatch(got, want, tolerance); i >= 0 {
		t.Fatalf("got[%d] = %v, want %v\nall got: %v", i, got[i], want[i], got)
	}
}

// firstFloatMismatch returns the index of the first element where got and
// want differ by more than tolerance, or -1 when all elements match. It
// expects len(got) >= len(want).
func firstFloatMismatch(got, want []float32, tolerance float64) int {
	for i := range want {
		g, w := got[i], want[i]
		if g == w {
			// Also covers equal infinities, whose difference would be NaN.
			continue
		}
		if math.IsNaN(float64(g)) && math.IsNaN(float64(w)) {
			continue
		}
		// A NaN difference compares false against everything, so it has to be
		// rejected explicitly rather than relying on "diff > tolerance".
		diff := math.Abs(float64(g - w))
		if diff > tolerance || math.IsNaN(diff) {
			return i
		}
	}
	return -1
}
110, 111, 114, 115, + } + + if len(got) != len(want) { + t.Fatalf("len(got) = %d, want %d", len(got), len(want)) + } + + for i := range want { + if got[i] != want[i] { + t.Fatalf("got[%d] = %v, want %v", i, got[i], want[i]) + } + } +} + +func TestResizePositionEmbeddingMatchesReferenceInterpolation(t *testing.T) { + values := []float32{ + 0, 10, + 20, 30, + } + got := resizePositionEmbedding(values, 1, 2, 2, 3, 3) + want := []float32{ + 0, 5, 10, + 10, 15, 20, + 20, 25, 30, + } + + if len(got) != len(want) { + t.Fatalf("len(got) = %d, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("got[%d] = %v, want %v", i, got[i], want[i]) + } + } +} + +func TestDynamicImageProcessorMatchesReferencePatchBudget(t *testing.T) { + p := ImageProcessor{ + imageSize: 512, + patchSize: 16, + numChannels: 3, + minNumPatches: 1024, + maxNumPatches: 13312, + projectorScale: 2, + imageMean: [3]float32{0.48145466, 0.4578275, 0.40821073}, + imageStd: [3]float32{0.26862954, 0.26130258, 0.27577711}, + } + + img := image.NewRGBA(image.Rect(0, 0, 400, 250)) + bounds := img.Bounds() + width, height := bounds.Dx(), bounds.Dy() + for y := range height { + for x := range width { + img.SetRGBA(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255}) + } + } + + tiles, err := p.ProcessImage(img) + if err != nil { + t.Fatalf("ProcessImage() error = %v", err) + } + if got, want := len(tiles), 1; got != want { + t.Fatalf("len(tiles) = %d, want %d", got, want) + } + if got, want := tiles[0].size, (image.Point{X: 672, Y: 416}); got != want { + t.Fatalf("tile size = %v, want %v", got, want) + } + if got, want := len(tiles[0].data), 3*672*416; got != want { + t.Fatalf("tile data len = %d, want %d", got, want) + } +} + +func sineWAV(t *testing.T, sampleRate int, frequency float64, seconds float64) []byte { + t.Helper() + + samples := int(float64(sampleRate) * seconds) + var pcm bytes.Buffer + for i := range samples { + v := 
int16(math.Sin(2*math.Pi*frequency*float64(i)/float64(sampleRate)) * 32767) + if err := binary.Write(&pcm, binary.LittleEndian, v); err != nil { + t.Fatal(err) + } + } + + var out bytes.Buffer + out.WriteString("RIFF") + if err := binary.Write(&out, binary.LittleEndian, uint32(36+pcm.Len())); err != nil { + t.Fatal(err) + } + out.WriteString("WAVE") + out.WriteString("fmt ") + if err := binary.Write(&out, binary.LittleEndian, uint32(16)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint16(1)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint16(1)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint32(sampleRate)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint32(sampleRate*2)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint16(2)); err != nil { + t.Fatal(err) + } + if err := binary.Write(&out, binary.LittleEndian, uint16(16)); err != nil { + t.Fatal(err) + } + out.WriteString("data") + if err := binary.Write(&out, binary.LittleEndian, uint32(pcm.Len())); err != nil { + t.Fatal(err) + } + out.Write(pcm.Bytes()) + return out.Bytes() +} diff --git a/model/models/nemotronh/model_vision.go b/model/models/nemotronh/model_vision.go new file mode 100644 index 000000000..aa6a146d5 --- /dev/null +++ b/model/models/nemotronh/model_vision.go @@ -0,0 +1,348 @@ +package nemotronh + +import ( + "math" + "sync" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +const nemotronVisionBatchSize = 1 + +type visionPatchGrid struct { + Width int + Height int +} + +type VisionPatchEmbedding struct { + *nn.Linear +} + +func packVisionPatchesCHW(values []float32, width, height, channels, patchSize int) []float32 { + patchesX, patchesY := width/patchSize, height/patchSize + patchDim := channels * patchSize * patchSize + plane := width * 
height + + patches := make([]float32, patchDim*patchesX*patchesY) + offset := 0 + for py := range patchesY { + for px := range patchesX { + for c := range channels { + channelBase := c * plane + for yy := range patchSize { + rowBase := (py*patchSize + yy) * width + for xx := range patchSize { + patches[offset] = values[channelBase+rowBase+px*patchSize+xx] + offset++ + } + } + } + } + } + + return patches +} + +func (p *VisionPatchEmbedding) ForwardPacked(ctx ml.Context, patches []float32, patchDim, numPatches int) ml.Tensor { + hiddenState := ctx.Input().FromFloats(patches, patchDim, numPatches) + hiddenState = hiddenState.Duplicate(ctx) + return p.Linear.Forward(ctx, hiddenState) +} + +func (p *VisionPatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor { + // Match the RADIO patch generator's exact flattening order: patches are laid + // out token-major with each token packed as channel, then patch-row, then + // patch-col. This is more explicit than the prior IM2Col path and likely + // slower, but it avoids backend-specific packing differences that caused the + // converted patch embedder to diverge badly from the reference model. 
+ width, height, channels := pixelValues.Dim(0), pixelValues.Dim(1), pixelValues.Dim(2) + patchesX, patchesY := width/patchSize, height/patchSize + patchDim := channels * patchSize * patchSize + + values := pixelValues.BackendGet() + return p.ForwardPacked(ctx, packVisionPatchesCHW(values, width, height, channels, patchSize), patchDim, patchesX*patchesY) +} + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_out"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor { + headDim := opts.hiddenSize / opts.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + key := sa.Key.Forward(ctx, hiddenState) + value := sa.Value.Forward(ctx, hiddenState) + + query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), nemotronVisionBatchSize) + key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), nemotronVisionBatchSize) + value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), nemotronVisionBatchSize) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), nemotronVisionBatchSize) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor { + return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx)) +} + +type VisionEncoderLayer struct { + LayerNorm1 *nn.LayerNorm `gguf:"ln1"` + SelfAttention *VisionSelfAttention + + LayerNorm2 *nn.LayerNorm `gguf:"ln2"` + MLP *VisionMLP +} + +func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor { + residual := hiddenState + + hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps) + hiddenState = 
l.SelfAttention.Forward(ctx, hiddenState, opts) + hiddenState = hiddenState.Add(ctx, residual) + + residual = hiddenState + hiddenState = l.LayerNorm2.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState) + return hiddenState.Add(ctx, residual) +} + +type VisionOptions struct { + hiddenSize int + numHeads int + imageSize int + patchSize int + eps float32 +} + +type VisionModel struct { + PatchEmbedding *VisionPatchEmbedding `gguf:"patch_embd"` + PositionEmbedding ml.Tensor `gguf:"position_embd"` + ClassEmbedding ml.Tensor `gguf:"cls_embd"` + + Layers []VisionEncoderLayer `gguf:"blk"` + + *VisionOptions + + resizedPositionEmbeddingsMu sync.Mutex + resizedPositionEmbeddings map[visionPatchGrid][]float32 +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor { + numPatches := patches.Width * patches.Height + + hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize) + return m.forwardPatchEmbeddings(ctx, hiddenState, patches, numPatches) +} + +func (m *VisionModel) ForwardPacked(ctx ml.Context, patchValues []float32, patches visionPatchGrid) ml.Tensor { + numPatches := patches.Width * patches.Height + patchDim := 0 + if numPatches > 0 { + patchDim = len(patchValues) / numPatches + } + hiddenState := m.PatchEmbedding.ForwardPacked(ctx, patchValues, patchDim, numPatches) + return m.forwardPatchEmbeddings(ctx, hiddenState, patches, numPatches) +} + +func (m *VisionModel) forwardPatchEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor { + if m.PositionEmbedding != nil { + positionEmbeddings := m.positionEmbeddings(ctx, hiddenState, patches, numPatches) + hiddenState = hiddenState.Add(ctx, positionEmbeddings) + } + + if m.ClassEmbedding != nil { + numPrefixTokens := m.ClassEmbedding.Dim(1) + classEmbeddings := m.ClassEmbedding.Cast(ctx, hiddenState.DType()) + classEmbeddings = classEmbeddings.Reshape(ctx, 
// positionEmbeddings returns the position embeddings to add to hiddenState
// for a patch grid of the given size.
//
// The stored table is used directly (truncated to numPatches tokens when it
// is longer) unless it holds a perfect-square number of tokens whose side
// length differs from the requested grid width or height. In that case the
// table is resized to the grid: first from the CPU-side cache when
// cachePositionEmbeddings can provide values, otherwise with an in-graph
// bilinear interpolation.
func (m *VisionModel) positionEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor {
	posTokens := m.PositionEmbedding.Dim(1)
	// Side length of the stored grid, assuming it is square; verified below.
	source := int(math.Sqrt(float64(posTokens)))

	positionEmbeddings := m.PositionEmbedding.Cast(ctx, hiddenState.DType())
	// No resize: the table is not a square grid, or it already matches the
	// requested grid in both dimensions.
	if !(source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height)) {
		if positionEmbeddings.Dim(1) > numPatches {
			positionEmbeddings = positionEmbeddings.Slice(ctx, 1, 0, numPatches, 1)
		}
		return positionEmbeddings
	}

	if cached, ok := m.cachePositionEmbeddings(ctx, hiddenState.Dim(0), patches); ok {
		return ctx.Input().FromFloats(cached, hiddenState.Dim(0), numPatches)
	}

	// Runner fit/reserve builds worst-case multimodal graphs before weights are
	// loaded, so the CPU cache path (bilinear with align_corners=false, per
	// cachePositionEmbeddings) cannot materialize source values there. Fall
	// back to a graph-only bilinear resize for reservation; the loaded
	// inference path above still uses the cached CPU-resized data.
	// NOTE(review): whether ml.SamplingModeBilinear samples exactly like the
	// CPU path is not visible here — confirm if this fallback is ever used for
	// real inference.
	positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
	positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
	positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
		patches.Width,
		patches.Height,
		hiddenState.Dim(0),
		1,
	}, ml.SamplingModeBilinear)
	positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
	return positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
}
+ cached = resizePositionEmbedding(positionEmbeddingsF32.Floats(), hidden, source, source, patches.Width, patches.Height) + + m.resizedPositionEmbeddingsMu.Lock() + if m.resizedPositionEmbeddings == nil { + m.resizedPositionEmbeddings = make(map[visionPatchGrid][]float32) + } + if existing := m.resizedPositionEmbeddings[patches]; existing != nil { + cached = existing + } else { + m.resizedPositionEmbeddings[patches] = cached + } + m.resizedPositionEmbeddingsMu.Unlock() + + return cached, true +} + +func resizePositionEmbedding(values []float32, hidden, sourceWidth, sourceHeight, targetWidth, targetHeight int) []float32 { + out := make([]float32, hidden*targetWidth*targetHeight) + + scaleX := float64(sourceWidth) / float64(targetWidth) + scaleY := float64(sourceHeight) / float64(targetHeight) + + for oy := range targetHeight { + srcY := scaleY*(float64(oy)+0.5) - 0.5 + y0 := int(math.Floor(srcY)) + y1 := min(y0+1, sourceHeight-1) + wy := float32(srcY - float64(y0)) + y0 = max(y0, 0) + + for ox := range targetWidth { + srcX := scaleX*(float64(ox)+0.5) - 0.5 + x0 := int(math.Floor(srcX)) + x1 := min(x0+1, sourceWidth-1) + wx := float32(srcX - float64(x0)) + x0 = max(x0, 0) + + t00 := (y0*sourceWidth + x0) * hidden + t01 := (y0*sourceWidth + x1) * hidden + t10 := (y1*sourceWidth + x0) * hidden + t11 := (y1*sourceWidth + x1) * hidden + dst := (oy*targetWidth + ox) * hidden + for h := range hidden { + v00 := values[t00+h] + v01 := values[t01+h] + v10 := values[t10+h] + v11 := values[t11+h] + top := v00 + (v01-v00)*wx + bot := v10 + (v11-v10)*wx + out[dst+h] = top + (bot-top)*wy + } + } + } + + return out +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)), + VisionOptions: &VisionOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1280)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + imageSize: int(c.Uint("vision.image_size", 512)), + patchSize: 
// pixelShuffleVisionOutputs merges each scaleFactor x scaleFactor
// neighbourhood of the vision patch grid into a single token. visionOutputs
// is treated as hiddenSize values per patch laid out over a
// patches.Width x patches.Height grid; the result has
// hiddenSize*scaleFactor*scaleFactor values per token and
// (patches.Width/scaleFactor)*(patches.Height/scaleFactor) tokens.
//
// For scaleFactor 2 each output token concatenates the hidden vectors of the
// patches at (y, x), (y, x+1), (y+1, x), (y+1, x+1) in that order, which is
// the order pinned by TestProjectorPixelShuffleMatchesReferenceV2Order.
//
// NOTE(review): the grid sizes use integer division, so this presumably
// requires patches.Width and patches.Height to be multiples of scaleFactor —
// confirm against the image processor.
func pixelShuffleVisionOutputs(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, scaleFactor int) ml.Tensor {
	hiddenSize := visionOutputs.Dim(0)
	// Treat zero or negative factors as "no merging".
	scaleFactor = max(scaleFactor, 1)

	merged := visionOutputs.Reshape(ctx, hiddenSize, patches.Width, patches.Height, 1)

	width := patches.Width / scaleFactor
	height := patches.Height / scaleFactor
	channels := hiddenSize * scaleFactor

	// Fold scaleFactor horizontally adjacent patches into the channel axis.
	merged = merged.Reshape(ctx, channels, width, patches.Height, 1)
	// Split the row axis into (row within neighbourhood, merged row).
	merged = merged.Reshape(ctx, channels, width, scaleFactor, height)
	// Move the row-within-neighbourhood axis next to the channels so the final
	// reshape can fold it into the channel axis as well.
	merged = merged.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	return merged.Reshape(ctx, channels*scaleFactor, width*height, 1)
}
@@ +package nemotronh + +import ( + "encoding/binary" + "fmt" + "math" + "math/cmplx" +) + +const ( + parakeetHopLength = 160 + parakeetNFFT = 512 + parakeetWinLength = 400 + parakeetPreemphasis = 0.97 + parakeetLogZeroGuardValue = 1.0 / (1 << 24) + parakeetNormalizeEps = 1e-5 +) + +func isAudioData(data []byte) bool { + return len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WAVE" +} + +func decodeWAV(data []byte, targetSampleRate int) ([]float32, error) { + if len(data) < 12 { + return nil, fmt.Errorf("WAV file too short") + } + if !isAudioData(data) { + return nil, fmt.Errorf("not a WAV file") + } + + var audioFormat uint16 + var numChannels, sampleRate, bitsPerSample int + var audioData []byte + foundFmt := false + + offset := 12 + for offset+8 <= len(data) { + chunkID := string(data[offset : offset+4]) + chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8])) + chunkEnd := min(offset+8+chunkSize, len(data)) + chunkData := data[offset+8 : chunkEnd] + + switch chunkID { + case "fmt ": + if len(chunkData) < 16 { + return nil, fmt.Errorf("fmt chunk too short") + } + audioFormat = binary.LittleEndian.Uint16(chunkData[0:2]) + numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4])) + sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8])) + bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16])) + if audioFormat == 0xfffe && len(chunkData) >= 26 { + audioFormat = binary.LittleEndian.Uint16(chunkData[24:26]) + } + foundFmt = true + case "data": + audioData = chunkData + } + + offset += 8 + chunkSize + if chunkSize%2 != 0 { + offset++ + } + } + + if !foundFmt { + return nil, fmt.Errorf("no fmt chunk found in WAV file") + } + if audioFormat != 1 && audioFormat != 3 { + return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat) + } + if audioData == nil { + return nil, fmt.Errorf("no data chunk found in WAV file") + } + if numChannels <= 0 { + return nil, fmt.Errorf("invalid WAV 
channel count: %d", numChannels) + } + + samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels) + if sampleRate != targetSampleRate { + samples = resampleLinear(samples, sampleRate, targetSampleRate) + } + return samples, nil +} + +func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 { + bytesPerSample := bits / 8 + if bytesPerSample <= 0 || channels <= 0 { + return nil + } + totalSamples := len(data) / (bytesPerSample * channels) + mono := make([]float32, totalSamples) + + for i := range totalSamples { + var sum float64 + for ch := range channels { + off := (i*channels + ch) * bytesPerSample + if off+bytesPerSample > len(data) { + break + } + switch { + case format == 1 && bits == 16: + v := int16(binary.LittleEndian.Uint16(data[off : off+2])) + sum += float64(v) / 32768.0 + case format == 1 && bits == 32: + v := int32(binary.LittleEndian.Uint32(data[off : off+4])) + sum += float64(v) / 2147483648.0 + case format == 1 && bits == 24: + v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16 + if v&0x800000 != 0 { + v |= ^0xffffff + } + sum += float64(v) / 8388608.0 + case format == 3 && bits == 32: + sum += float64(math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4]))) + case format == 1 && bits == 8: + sum += (float64(data[off]) - 128.0) / 128.0 + } + } + mono[i] = float32(sum / float64(channels)) + } + return mono +} + +func resampleLinear(samples []float32, fromRate, toRate int) []float32 { + if fromRate <= 0 || toRate <= 0 || len(samples) == 0 { + return samples + } + n := int(float64(len(samples)) / float64(fromRate) * float64(toRate)) + if n <= 1 { + return slicesCloneOne(samples) + } + out := make([]float32, n) + for i := range n { + pos := float64(i) * float64(len(samples)-1) / float64(n-1) + idx := int(pos) + frac := float32(pos - float64(idx)) + if idx+1 < len(samples) { + out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac + } else { + out[i] = samples[idx] + } + } + return 
out +} + +func slicesCloneOne(samples []float32) []float32 { + if len(samples) == 0 { + return nil + } + return []float32{samples[0]} +} + +func computeParakeetMelSpectrogram(samples []float32, extractor *AudioFeatureExtractor, opts *AudioOptions) ([]float32, int, int, error) { + if len(samples) == 0 { + return nil, 0, 0, fmt.Errorf("audio too short to encode") + } + if opts == nil { + opts = defaultAudioOptions() + } + + melBins := opts.melBins + freqBins := parakeetNFFT/2 + 1 + window, melFilters := extractor.windowAndFilters(melBins, freqBins, opts.sampleRate) + if len(window) != parakeetWinLength { + return nil, 0, 0, fmt.Errorf("invalid Parakeet window length: %d", len(window)) + } + if len(melFilters) != melBins*freqBins { + return nil, 0, 0, fmt.Errorf("invalid Parakeet mel filter shape: %d", len(melFilters)) + } + + emphasized := make([]float32, len(samples)) + emphasized[0] = samples[0] + for i := 1; i < len(samples); i++ { + emphasized[i] = samples[i] - parakeetPreemphasis*samples[i-1] + } + + frames := len(samples)/parakeetHopLength + 1 + validFrames := max(1, len(samples)/parakeetHopLength) + if validFrames > frames { + validFrames = frames + } + + result := make([]float32, frames*melBins) + fftInput := make([]complex128, parakeetNFFT) + winOffset := (parakeetNFFT - parakeetWinLength) / 2 + centerPad := parakeetNFFT / 2 + + for frame := range frames { + for i := range parakeetNFFT { + fftInput[i] = 0 + } + for i := range parakeetWinLength { + src := frame*parakeetHopLength + i + winOffset - centerPad + if src >= 0 && src < len(emphasized) { + fftInput[i+winOffset] = complex(float64(emphasized[src])*float64(window[i]), 0) + } + } + + fft(fftInput) + + for mel := range melBins { + var v float64 + filterOffset := mel * freqBins + for freq := range freqBins { + mag := cmplx.Abs(fftInput[freq]) + v += float64(melFilters[filterOffset+freq]) * mag * mag + } + result[frame*melBins+mel] = float32(math.Log(v + parakeetLogZeroGuardValue)) + } + } + + for mel := 
range melBins { + var sum float64 + for frame := range validFrames { + sum += float64(result[frame*melBins+mel]) + } + mean := sum / float64(validFrames) + + var variance float64 + for frame := range validFrames { + d := float64(result[frame*melBins+mel]) - mean + variance += d * d + } + denom := max(1, validFrames-1) + std := math.Sqrt(variance / float64(denom)) + + for frame := range frames { + idx := frame*melBins + mel + if frame >= validFrames { + result[idx] = 0 + continue + } + result[idx] = float32((float64(result[idx]) - mean) / (std + parakeetNormalizeEps)) + } + } + + return result, frames, validFrames, nil +} + +func defaultParakeetWindow() []float32 { + window := make([]float32, parakeetWinLength) + for i := range window { + window[i] = float32(0.5 - 0.5*math.Cos(2*math.Pi*float64(i)/float64(parakeetWinLength-1))) + } + return window +} + +func buildSlaneyMelFilterBank(numFreqBins, numMels int, sampleRate int) []float32 { + hzToMel := func(f float64) float64 { + if f < 1000 { + return 3 * f / 200 + } + return 15 + math.Log(f/1000)*27/math.Log(6.4) + } + melToHz := func(m float64) float64 { + if m < 15 { + return 200 * m / 3 + } + return 1000 * math.Exp(math.Log(6.4)*(m-15)/27) + } + + minMel := hzToMel(0) + maxMel := hzToMel(float64(sampleRate) / 2) + mels := make([]float64, numMels+2) + freqs := make([]float64, numMels+2) + for i := range mels { + mels[i] = minMel + (maxMel-minMel)*float64(i)/float64(numMels+1) + freqs[i] = melToHz(mels[i]) + } + + fftFreqs := make([]float64, numFreqBins) + for i := range fftFreqs { + fftFreqs[i] = float64(i) * float64(sampleRate) / float64(parakeetNFFT) + } + + filters := make([]float32, numMels*numFreqBins) + for mel := range numMels { + left, center, right := freqs[mel], freqs[mel+1], freqs[mel+2] + enorm := 2.0 / (right - left) + for freq, fftFreq := range fftFreqs { + var lower, upper float64 + if center > left { + lower = (fftFreq - left) / (center - left) + } + if right > center { + upper = (right - fftFreq) / 
// fft replaces x with its forward discrete Fourier transform using an
// in-place iterative radix-2 algorithm. Slices of length 0 or 1 are left
// untouched. len(x) is expected to be a power of two; for other lengths the
// butterfly stage indexes past the end of the slice.
func fft(x []complex128) {
	n := len(x)
	if n <= 1 {
		return
	}

	// Reorder the input into bit-reversed index order. rev tracks the
	// bit-reversed counterpart of i and is advanced with a reversed increment.
	for i, rev := 1, 0; i < n; i++ {
		mask := n >> 1
		for ; rev&mask != 0; mask >>= 1 {
			rev ^= mask
		}
		rev ^= mask
		if i < rev {
			x[i], x[rev] = x[rev], x[i]
		}
	}

	// Combine transforms of doubling size with butterflies.
	for size := 2; size <= n; size *= 2 {
		half := size >> 1
		angle := 2 * math.Pi / float64(size)
		step := complex(math.Cos(angle), -math.Sin(angle))
		for base := 0; base < n; base += size {
			twiddle := complex(1, 0)
			for lo := base; lo < base+half; lo++ {
				hi := lo + half
				odd := twiddle * x[hi]
				x[lo], x[hi] = x[lo]+odd, x[lo]-odd
				twiddle *= step
			}
		}
	}
}
!thinkValue.Bool() + p.state = lagunaParserStateContent + p.allowLeadingThinkOpen = false + return tools +} + +func (p *LagunaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + var contentSB, thinkingSB strings.Builder + + for { + progress := false + switch p.state { + case lagunaParserStateThinking: + progress, thinking = p.consumeThinking(done) + if p.thinkingEnabled { + thinkingSB.WriteString(thinking) + } + case lagunaParserStateContent: + var parsedCalls []api.ToolCall + progress, content, parsedCalls, err = p.consumeContent(done) + if err != nil { + return "", "", nil, err + } + contentSB.WriteString(content) + calls = append(calls, parsedCalls...) + case lagunaParserStateTool: + var call api.ToolCall + progress, call, err = p.consumeTool(done) + if err != nil { + return "", "", nil, err + } + if progress { + calls = append(calls, call) + } + } + if !progress { + break + } + } + + return contentSB.String(), thinkingSB.String(), calls, nil +} + +func (p *LagunaParser) consumeThinking(done bool) (bool, string) { + acc := p.buffer.String() + if p.allowLeadingThinkOpen { + trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace) + if strings.HasPrefix(trimmed, lagunaThinkingOpenTag) { + p.buffer.Reset() + p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingOpenTag), unicode.IsSpace)) + p.allowLeadingThinkOpen = false + return true, "" + } + if strings.HasPrefix(lagunaThinkingOpenTag, trimmed) && !done { + return false, "" + } + p.allowLeadingThinkOpen = false + } + + if idx := strings.Index(acc, lagunaThinkingCloseTag); idx != -1 { + thinking := acc[:idx] + after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingCloseTag):], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = lagunaParserStateContent + return true, thinking + } + if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 { + thinking := 
strings.TrimRightFunc(acc[:idx], unicode.IsSpace) + after := acc[idx+len(lagunaToolCallOpenTag):] + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = lagunaParserStateTool + return true, thinking + } + if done { + p.buffer.Reset() + p.state = lagunaParserStateContent + return acc != "", acc + } + + overlapLen := max(overlap(acc, lagunaThinkingCloseTag), overlap(acc, lagunaToolCallOpenTag)) + trailingLen := trailingWhitespaceLen(acc) + keep := max(overlapLen, trailingLen) + if keep > 0 && keep < len(acc) { + emit := acc[:len(acc)-keep] + p.buffer.Reset() + p.buffer.WriteString(acc[len(acc)-keep:]) + return emit != "", emit + } + return false, "" +} + +func (p *LagunaParser) consumeContent(done bool) (bool, string, []api.ToolCall, error) { + acc := p.buffer.String() + if p.thinkingEnabled || p.thinkingSuppressed { + if idx := strings.Index(acc, lagunaThinkingOpenTag); idx != -1 { + content := acc[:idx] + after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingOpenTag):], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = lagunaParserStateThinking + p.allowLeadingThinkOpen = false + return true, content, nil, nil + } + if !done { + overlapLen := overlap(acc, lagunaThinkingOpenTag) + if overlapLen > 0 && overlapLen < len(acc) { + content := acc[:len(acc)-overlapLen] + p.buffer.Reset() + p.buffer.WriteString(acc[len(acc)-overlapLen:]) + return content != "", content, nil, nil + } + } + } + if p.thinkingEnabled { + trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace) + if strings.HasPrefix(trimmed, lagunaThinkingCloseTag) { + p.buffer.Reset() + p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingCloseTag), unicode.IsSpace)) + return true, "", nil, nil + } + if strings.HasPrefix(lagunaThinkingCloseTag, trimmed) && !done { + return false, "", nil, nil + } + } + if p.thinkingSuppressed { + trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace) + if strings.HasPrefix(trimmed, lagunaThinkingCloseTag) { + 
p.buffer.Reset() + p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingCloseTag), unicode.IsSpace)) + return true, "", nil, nil + } + if strings.HasPrefix(lagunaThinkingCloseTag, trimmed) && !done { + return false, "", nil, nil + } + } + if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 { + content := strings.TrimRightFunc(acc[:idx], unicode.IsSpace) + after := acc[idx+len(lagunaToolCallOpenTag):] + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = lagunaParserStateTool + return true, content, nil, nil + } + if idx := strings.Index(acc, lagunaUserOpenTag); idx != -1 && len(p.tools) > 0 { + before := strings.TrimRightFunc(acc[:idx], unicode.IsSpace) + afterOpen := acc[idx+len(lagunaUserOpenTag):] + if closeIdx := strings.Index(afterOpen, lagunaUserCloseTag); closeIdx != -1 { + raw := afterOpen[:closeIdx] + if call, ok := p.parseToolAlias(raw); ok { + after := strings.TrimLeftFunc(afterOpen[closeIdx+len(lagunaUserCloseTag):], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + return true, before, []api.ToolCall{call}, nil + } + } else if !done { + if idx > 0 { + p.buffer.Reset() + p.buffer.WriteString(acc[idx:]) + return true, before, nil, nil + } + return false, "", nil, nil + } + } + if len(p.tools) > 0 { + if progress, content, call, ok, err := p.consumeStandaloneJSONTool(done); ok || err != nil { + if err != nil { + return false, "", nil, err + } + if progress { + return true, content, []api.ToolCall{call}, nil + } + return false, "", nil, nil + } + } + if done { + p.buffer.Reset() + return acc != "", acc, nil, nil + } + overlapLen := max(overlap(acc, lagunaToolCallOpenTag), overlap(acc, lagunaUserOpenTag)) + if p.thinkingEnabled || p.thinkingSuppressed { + overlapLen = max(overlapLen, overlap(acc, lagunaThinkingOpenTag)) + } + if p.thinkingSuppressed { + overlapLen = max(overlapLen, overlap(acc, lagunaThinkingCloseTag)) + } + trailingLen := trailingWhitespaceLen(acc) + keep := max(overlapLen, 
trailingLen) + if keep > 0 && keep < len(acc) { + emit := acc[:len(acc)-keep] + p.buffer.Reset() + p.buffer.WriteString(acc[len(acc)-keep:]) + return emit != "", emit, nil, nil + } + if keep == 0 && acc != "" { + p.buffer.Reset() + return true, acc, nil, nil + } + return false, "", nil, nil +} + +func (p *LagunaParser) consumeStandaloneJSONTool(done bool) (progress bool, content string, call api.ToolCall, ok bool, err error) { + acc := p.buffer.String() + jsonIdx := strings.Index(acc, "{") + if jsonIdx == -1 { + return false, "", api.ToolCall{}, false, nil + } + + before := strings.TrimRightFunc(acc[:jsonIdx], unicode.IsSpace) + raw := strings.TrimLeftFunc(acc[jsonIdx:], unicode.IsSpace) + if !lagunaLooksLikeJSONToolCall(raw, done) { + return false, "", api.ToolCall{}, false, nil + } + + if !done && !json.Valid([]byte(strings.TrimSpace(raw))) { + if before != "" { + p.buffer.Reset() + p.buffer.WriteString(acc[jsonIdx:]) + return true, before, api.ToolCall{}, true, nil + } + return false, "", api.ToolCall{}, true, nil + } + + call, err = parseLagunaToolCall(raw, p.tools) + if err != nil { + return false, "", api.ToolCall{}, true, err + } + call.Function.Index = p.callIndex + p.callIndex++ + p.buffer.Reset() + p.state = lagunaParserStateContent + return true, before, call, true, nil +} + +func lagunaLooksLikeJSONToolCall(raw string, done bool) bool { + trimmed := strings.TrimLeftFunc(raw, unicode.IsSpace) + if !strings.HasPrefix(trimmed, "{") { + return false + } + if strings.Contains(trimmed, `"name"`) || strings.Contains(trimmed, `"arguments"`) { + return true + } + if done { + return false + } + return strings.HasPrefix(trimmed, `{"`) || strings.HasPrefix(trimmed, "{\n") || strings.HasPrefix(trimmed, "{\r\n") +} + +func (p *LagunaParser) parseToolAlias(raw string) (api.ToolCall, bool) { + raw = cleanLagunaToolCallRaw(raw) + name, ok := lagunaToolCallName(raw) + if !ok { + return api.ToolCall{}, false + } + if _, ok := lagunaResolveToolName(name, p.tools); !ok { + 
return api.ToolCall{}, false + } + call, err := parseLagunaToolCall(raw, p.tools) + if err != nil { + return api.ToolCall{}, false + } + call.Function.Index = p.callIndex + p.callIndex++ + return call, true +} + +func lagunaResolveToolName(name string, tools []api.Tool) (string, bool) { + for i := range tools { + if tools[i].Function.Name == name { + return name, true + } + } + + aliases := map[string]string{ + "read_file": "read", + "write_file": "write", + "edit_file": "edit", + "web_fetch": "webfetch", + } + if alias, ok := aliases[name]; ok { + for i := range tools { + if tools[i].Function.Name == alias { + return alias, true + } + } + } + return name, false +} + +func cleanLagunaToolCallRaw(raw string) string { + raw = strings.TrimSpace(raw) + for strings.HasPrefix(raw, lagunaToolCallOpenTag) { + raw = strings.TrimSpace(strings.TrimPrefix(raw, lagunaToolCallOpenTag)) + } + if idx := strings.Index(raw, lagunaToolCallCloseTag); idx != -1 { + raw = strings.TrimSpace(raw[:idx]) + } + if idx := strings.Index(raw, lagunaToolCallOpenTag); idx != -1 { + before := strings.TrimSpace(raw[:idx]) + if before != "" { + return before + } + raw = strings.TrimSpace(raw[idx+len(lagunaToolCallOpenTag):]) + } + return raw +} + +func lagunaToolCallName(raw string) (string, bool) { + raw = cleanLagunaToolCallRaw(raw) + if strings.HasPrefix(raw, "{") { + var parsed struct { + Name string `json:"name"` + } + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return "", false + } + name := strings.TrimSpace(parsed.Name) + return name, name != "" + } + + nameEnd := strings.Index(raw, "") + if nameEnd < 0 { + nameEnd = strings.Index(raw, "{") + } + if nameEnd < 0 { + nameEnd = strings.IndexAny(raw, "\r\n") + } + if nameEnd < 0 { + nameEnd = len(raw) + } + name := strings.TrimSpace(raw[:nameEnd]) + return name, name != "" +} + +func (p *LagunaParser) consumeTool(done bool) (bool, api.ToolCall, error) { + acc := p.buffer.String() + if idx := strings.Index(acc, 
lagunaToolCallCloseTag); idx != -1 { + raw := acc[:idx] + after := strings.TrimLeftFunc(acc[idx+len(lagunaToolCallCloseTag):], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = lagunaParserStateContent + call, err := parseLagunaToolCall(raw, p.tools) + if err != nil { + return false, api.ToolCall{}, err + } + call.Function.Index = p.callIndex + p.callIndex++ + return true, call, nil + } + if done && strings.TrimSpace(acc) != "" { + p.buffer.Reset() + p.state = lagunaParserStateContent + call, err := parseLagunaToolCall(acc, p.tools) + if err != nil { + return false, api.ToolCall{}, err + } + call.Function.Index = p.callIndex + p.callIndex++ + return true, call, nil + } + return false, api.ToolCall{}, nil +} + +var lagunaArgRE = regexp.MustCompile(`(?s)(.*?)\s*(.*?)`) + +func parseLagunaToolCall(raw string, tools []api.Tool) (api.ToolCall, error) { + raw = cleanLagunaToolCallRaw(raw) + if strings.HasPrefix(raw, "{") { + var parsed struct { + Name string `json:"name"` + Arguments api.ToolCallFunctionArguments `json:"arguments"` + } + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call: %w", err) + } + if parsed.Name == "" { + return api.ToolCall{}, fmt.Errorf("empty Laguna tool call name") + } + if name, ok := lagunaResolveToolName(parsed.Name, tools); ok { + parsed.Name = name + } + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: parsed.Name, + Arguments: parsed.Arguments, + }, + }, nil + } + + nameEnd := strings.Index(raw, "") + name := raw + argsText := "" + if nameEnd >= 0 { + name = raw[:nameEnd] + argsText = raw[nameEnd:] + } else if jsonStart := strings.Index(raw, "{"); jsonStart >= 0 { + name = raw[:jsonStart] + argsText = raw[jsonStart:] + } + name = strings.TrimSpace(name) + if resolved, ok := lagunaResolveToolName(name, tools); ok { + name = resolved + } + + var matchedTool *api.Tool + for i := range tools { + if 
tools[i].Function.Name == name { + matchedTool = &tools[i] + break + } + } + + call := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: name, + Arguments: api.NewToolCallFunctionArguments(), + }, + } + if strings.HasPrefix(strings.TrimSpace(argsText), "{") { + if err := json.Unmarshal([]byte(strings.TrimSpace(argsText)), &call.Function.Arguments); err != nil { + return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call arguments: %w", err) + } + return call, nil + } + for _, match := range lagunaArgRE.FindAllStringSubmatch(argsText, -1) { + key := strings.TrimSpace(match[1]) + value := match[2] + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok { + if len(prop.AnyOf) > 0 { + for _, anyOfProp := range prop.AnyOf { + paramType = append(paramType, anyOfProp.Type...) + } + } else { + paramType = prop.Type + } + } + } + call.Function.Arguments.Set(key, parseValue(value, paramType)) + } + return call, nil +} diff --git a/model/parsers/laguna_test.go b/model/parsers/laguna_test.go new file mode 100644 index 000000000..73fb42487 --- /dev/null +++ b/model/parsers/laguna_test.go @@ -0,0 +1,484 @@ +package parsers + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +func lagunaTestTools() []api.Tool { + props := api.NewToolPropertiesMap() + props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}}) + props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}}) + return []api.Tool{{ + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} +} + +func TestLagunaParserToolCall(t *testing.T) { + parser := ParserForName("laguna") + if parser == nil { + t.Fatal("expected laguna parser") + } + if !parser.HasToolSupport() || !parser.HasThinkingSupport() { + t.Fatal("laguna parser should advertise tools and thinking") + } + 
+ parser.Init(lagunaTestTools(), nil, nil) + content, thinking, calls, err := parser.Add("get_weather\nlocation\nParis\ndays\n3\n", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" { + t.Fatalf("content=%q thinking=%q, want empty", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("location"); got != "Paris" { + t.Fatalf("location=%v, want Paris", got) + } + if got, _ := calls[0].Function.Arguments.Get("days"); got != 3 { + t.Fatalf("days=%v, want 3", got) + } +} + +func TestLagunaParserJSONToolCall(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + _, _, calls, err := parser.Add("\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}\n", true) + if err != nil { + t.Fatal(err) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("location"); got != "Paris" { + t.Fatalf("location=%v, want Paris", got) + } + if got, _ := calls[0].Function.Arguments.Get("days"); got != float64(3) { + t.Fatalf("days=%v, want 3", got) + } +} + +func TestLagunaParserStandaloneJSONToolCall(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + content, thinking, calls, err := parser.Add("{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" { + t.Fatalf("content=%q thinking=%q", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", 
calls[0].Function.Name) + } +} + +func TestLagunaParserStandaloneJSONToolCallAfterLeadingContent(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + content, thinking, calls, err := parser.Add("Let me call the weather tool.\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}", true) + if err != nil { + t.Fatal(err) + } + if content != "Let me call the weather tool." || thinking != "" { + t.Fatalf("content=%q thinking=%q", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } +} + +func TestLagunaParserStreamingStandaloneJSONToolCall(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + content, thinking, calls, err := parser.Add("{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco,", false) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" || len(calls) != 0 { + t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } + + content, thinking, calls, err = parser.Add(" CA\"}}", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" || len(calls) != 1 { + t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco, CA" { + t.Fatalf("location=%v, want San Francisco, CA", got) + } +} + +func TestLagunaParserNameLineJSONToolCall(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + _, _, calls, err := parser.Add("get_weather\n{\"location\":\"San Francisco\"}", true) + if err != nil { + t.Fatal(err) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, 
want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco" { + t.Fatalf("location=%v, want San Francisco", got) + } +} + +func TestLagunaParserNormalizesCommonToolAliases(t *testing.T) { + props := api.NewToolPropertiesMap() + props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{{ + Function: api.ToolFunction{ + Name: "read", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} + + parser := ParserForName("laguna") + parser.Init(tools, nil, nil) + + _, _, calls, err := parser.Add("\n{\"name\":\"read_file\",\"arguments\":{\"path\":\"./go.mod\"}}\n", true) + if err != nil { + t.Fatal(err) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "read" { + t.Fatalf("name=%q, want read", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("path"); got != "./go.mod" { + t.Fatalf("path=%v, want ./go.mod", got) + } +} + +func TestLagunaParserIgnoresDuplicatedNestedToolCall(t *testing.T) { + props := api.NewToolPropertiesMap() + props.Set("name", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{{ + Function: api.ToolFunction{ + Name: "skill", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} + + parser := ParserForName("laguna") + parser.Init(tools, nil, nil) + + _, _, calls, err := parser.Add("skill\n{\"name\":\"git-diff-review\"}\nskill\n{\"name\":\"git-diff-review\"}", true) + if err != nil { + t.Fatal(err) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "skill" { + t.Fatalf("name=%q, want skill", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("name"); got != "git-diff-review" { + t.Fatalf("name arg=%v, want git-diff-review", got) + } +} + 
+func TestLagunaParserThinkingThenTool(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true}) + + content, thinking, calls, err := parser.Add("Need current weather.\nget_weather\nlocation\nSF\n", true) + if err != nil { + t.Fatal(err) + } + if content != "" { + t.Fatalf("content=%q, want empty", content) + } + if thinking != "Need current weather." { + t.Fatalf("thinking=%q, want reasoning", thinking) + } + if len(calls) != 1 || calls[0].Function.Name != "get_weather" { + t.Fatalf("unexpected calls: %#v", calls) + } +} + +func TestLagunaParserUserTaggedToolAlias(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + content, thinking, calls, err := parser.Add("get_weather\nlocation\nSan Francisco, CA\n", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" { + t.Fatalf("content=%q thinking=%q, want empty", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("name=%q, want get_weather", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco, CA" { + t.Fatalf("location=%v, want San Francisco, CA", got) + } +} + +func TestLagunaParserUserTaggedToolAliasWithLeadingContent(t *testing.T) { + parser := ParserForName("laguna") + props := api.NewToolPropertiesMap() + props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{{ + Function: api.ToolFunction{ + Name: "read", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} + parser.Init(tools, nil, nil) + + content, thinking, calls, err := parser.Add("I'll read the file for you.\nread\npath\n/Users/test/code/myproject/go.mod\n", true) + if err != nil { + t.Fatal(err) + } + if content != "I'll read the file for you." 
|| thinking != "" { + t.Fatalf("content=%q thinking=%q", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "read" { + t.Fatalf("name=%q, want read", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" { + t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got) + } +} + +func TestLagunaParserUserTaggedJSONToolCallWithLeadingContent(t *testing.T) { + parser := ParserForName("laguna") + props := api.NewToolPropertiesMap() + props.Set("command", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{{ + Function: api.ToolFunction{ + Name: "bash", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} + parser.Init(tools, nil, nil) + + content, thinking, calls, err := parser.Add("I'll run git diff for you.\n{\"name\":\"bash\",\"arguments\":{\"command\":\"git diff main\"}}\n", true) + if err != nil { + t.Fatal(err) + } + if content != "I'll run git diff for you." 
|| thinking != "" { + t.Fatalf("content=%q thinking=%q", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "bash" { + t.Fatalf("name=%q, want bash", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("command"); got != "git diff main" { + t.Fatalf("command=%v, want git diff main", got) + } +} + +func TestLagunaParserStreamingUserTaggedToolAliasAfterContent(t *testing.T) { + parser := ParserForName("laguna") + props := api.NewToolPropertiesMap() + props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{{ + Function: api.ToolFunction{ + Name: "read", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }} + parser.Init(tools, nil, nil) + + content, thinking, calls, err := parser.Add("I'll read the file for you.read\npath\n/Users/test/code/myproject/go.mod\n", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" { + t.Fatalf("second chunk content=%q thinking=%q", content, thinking) + } + if len(calls) != 1 { + t.Fatalf("calls=%d, want 1", len(calls)) + } + if calls[0].Function.Name != "read" { + t.Fatalf("name=%q, want read", calls[0].Function.Name) + } + if got, _ := calls[0].Function.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" { + t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got) + } +} + +func TestLagunaParserUserTaggedNonToolContent(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + + content, thinking, calls, err := parser.Add("hello", true) + if err != nil { + t.Fatal(err) + } + if content != "hello" || thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingDefaultsOn(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, nil) + content, thinking, calls, err := parser.Add("Need to 
reason.\nDirect answer.", true) + if err != nil { + t.Fatal(err) + } + if content != "Direct answer." || thinking != "Need to reason." || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingDefaultsOnWhenToolsPresent(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, nil) + content, thinking, calls, err := parser.Add("Need to reason.\nget_weather\nlocation\nParis\n", true) + if err != nil { + t.Fatal(err) + } + if thinking != "Need to reason." || len(calls) != 1 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } + if content != "" { + t.Fatalf("content=%q, want thinking block suppressed from content when default thinking is enabled", content) + } +} + +func TestLagunaParserThinkingExplicitlyDisabled(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, &api.ThinkValue{Value: false}) + content, thinking, calls, err := parser.Add("Hidden?\nDirect answer.", true) + if err != nil { + t.Fatal(err) + } + if content != "Direct answer." 
|| thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingExplicitlyDisabledDropsLeadingCloseTag(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, &api.ThinkValue{Value: false}) + content, thinking, calls, err := parser.Add("\nTokyo\n", true) + if err != nil { + t.Fatal(err) + } + if content != "Tokyo\n" || thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingEnabledDropsLeadingCloseTag(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, &api.ThinkValue{Value: true}) + content, thinking, calls, err := parser.Add("\nTokyo\n", true) + if err != nil { + t.Fatal(err) + } + if content != "Tokyo\n" || thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingDefaultOnDropsLeadingCloseTag(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, nil) + content, thinking, calls, err := parser.Add("\nTokyo\n", true) + if err != nil { + t.Fatal(err) + } + if content != "Tokyo\n" || thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserThinkingEnabledUntaggedAnswerIsContent(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(nil, nil, &api.ThinkValue{Value: true}) + content, thinking, calls, err := parser.Add("Direct answer.", true) + if err != nil { + t.Fatal(err) + } + if content != "Direct answer." 
|| thinking != "" || len(calls) != 0 { + t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} + +func TestLagunaParserSplitToolTag(t *testing.T) { + parser := ParserForName("laguna") + parser.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true}) + + content, thinking, calls, err := parser.Add("Need lookupget_weather\nlocation\nSF\n", true) + if err != nil { + t.Fatal(err) + } + if content != "" || thinking != "" || len(calls) != 1 { + t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls)) + } +} diff --git a/model/parsers/nemotron3nano.go b/model/parsers/nemotron3nano.go index eaa6cb87a..4bb82e5a1 100644 --- a/model/parsers/nemotron3nano.go +++ b/model/parsers/nemotron3nano.go @@ -16,14 +16,17 @@ const ( ) const ( + nemotronThinkOpen = "" nemotronThinkClose = "" nemotronToolCallOpen = "" ) type Nemotron3NanoParser struct { - state Nemotron3NanoParserState - buffer strings.Builder - toolParser *Qwen3CoderParser + state Nemotron3NanoParserState + buffer strings.Builder + toolParser *Qwen3CoderParser + maybeThinkingOpenAtBOL bool + skipThinkingLeadingWS bool } func (p *Nemotron3NanoParser) HasToolSupport() bool { return true } @@ -32,14 +35,18 @@ func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true } func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { p.toolParser = &Qwen3CoderParser{} p.toolParser.Init(tools, nil, nil) + p.buffer.Reset() + p.maybeThinkingOpenAtBOL = false + p.skipThinkingLeadingWS = false - thinkingEnabled := thinkValue != nil && thinkValue.Bool() + thinkingEnabled := thinkValue == nil || thinkValue.Bool() prefill := lastMessage != nil && lastMessage.Role == "assistant" if !thinkingEnabled || (prefill && lastMessage.Content != "") { p.state = Nemotron3NanoCollectingContent } else { p.state = Nemotron3NanoCollectingThinking + p.maybeThinkingOpenAtBOL = true } return tools @@ -61,6 +68,29 @@ func (p 
*Nemotron3NanoParser) Add(s string, done bool) (content string, thinking // Nemotron3NanoCollectingThinking - buffer and look for end markers p.buffer.WriteString(s) + if p.skipThinkingLeadingWS { + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(trimmed) + if trimmed == "" { + return "", "", nil, nil + } + p.skipThinkingLeadingWS = false + } + if p.stripOpeningThinkTag() { + return p.Add("", done) + } + if p.maybeThinkingOpenAtBOL { + bufStr := p.buffer.String() + trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace) + if trimmed == "" || overlap(trimmed, nemotronThinkOpen) == len(trimmed) { + if len(trimmed) != len(bufStr) { + p.buffer.Reset() + p.buffer.WriteString(trimmed) + } + return "", "", nil, nil + } + } bufStr := p.buffer.String() // Look for end of thinking: or (model may skip ) @@ -124,3 +154,35 @@ func (p *Nemotron3NanoParser) emitThinking(bufStr string) string { p.buffer.Reset() return bufStr } + +func (p *Nemotron3NanoParser) stripOpeningThinkTag() bool { + if !p.maybeThinkingOpenAtBOL { + return false + } + + bufStr := p.buffer.String() + trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace) + if trimmed == "" { + p.buffer.Reset() + return false + } + + if strings.HasPrefix(trimmed, nemotronThinkOpen) { + p.buffer.Reset() + p.buffer.WriteString(strings.TrimLeftFunc(trimmed[len(nemotronThinkOpen):], unicode.IsSpace)) + p.maybeThinkingOpenAtBOL = false + p.skipThinkingLeadingWS = true + return true + } + + if overlap(trimmed, nemotronThinkOpen) == len(trimmed) { + if len(trimmed) != len(bufStr) { + p.buffer.Reset() + p.buffer.WriteString(trimmed) + } + return false + } + + p.maybeThinkingOpenAtBOL = false + return false +} diff --git a/model/parsers/nemotron3nano_test.go b/model/parsers/nemotron3nano_test.go index 2ba36a156..c52dd34cb 100644 --- a/model/parsers/nemotron3nano_test.go +++ b/model/parsers/nemotron3nano_test.go @@ -82,6 +82,20 @@ func TestNemotron3NanoParser(t 
*testing.T) { expectedThinking: "My thoughts...", expectedContent: "Content here.", }, + { + name: "leading open think tag is ignored", + input: "\nLet me think about this...\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think about this...", + expectedContent: "Here is my answer.", + }, + { + name: "empty explicit think block is ignored", + input: "\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "", + expectedContent: "Here is my answer.", + }, } for _, tt := range tests { @@ -191,6 +205,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) { }, }, }, + { + name: "leading open think tag split across chunks", + chunks: []string{"", "\nThink first", "", "\nDone."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Think first", + expectedContent: "Done.", + }, } for _, tt := range tests { @@ -265,11 +286,11 @@ func TestNemotron3NanoParser_Init(t *testing.T) { } }) - t.Run("starts in content state when nil thinkValue", func(t *testing.T) { + t.Run("starts in thinking state when nil thinkValue", func(t *testing.T) { p := &Nemotron3NanoParser{} p.Init(nil, nil, nil) - if p.state != Nemotron3NanoCollectingContent { - t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state) + if p.state != Nemotron3NanoCollectingThinking { + t.Errorf("expected state Nemotron3NanoCollectingThinking, got %v", p.state) } }) @@ -281,6 +302,29 @@ func TestNemotron3NanoParser_Init(t *testing.T) { t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state) } }) + + t.Run("reinit clears buffered state", func(t *testing.T) { + p := &Nemotron3NanoParser{} + p.Init(nil, nil, &api.ThinkValue{Value: true}) + if _, _, _, err := p.Add("thinking in progress", false); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + p.Init(nil, nil, &api.ThinkValue{Value: false}) + content, thinking, calls, err := p.Add("content only", true) + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + if content != "content only" { + t.Fatalf("expected content after reinit, got %q", content) + } + if thinking != "" { + t.Fatalf("expected no thinking after reinit, got %q", thinking) + } + if len(calls) != 0 { + t.Fatalf("expected no tool calls after reinit, got %v", calls) + } + }) } func TestNemotron3NanoParser_WithTools(t *testing.T) { diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 38cad3e79..2f25ae613 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -87,6 +87,8 @@ func ParserForName(name string) Parser { return &LFM2Parser{hasThinkingSupport: false} case "lfm2-thinking": return &LFM2Parser{hasThinkingSupport: true} + case "laguna": + return &LagunaParser{} default: return nil } diff --git a/model/renderers/laguna.go b/model/renderers/laguna.go new file mode 100644 index 000000000..0400b5158 --- /dev/null +++ b/model/renderers/laguna.go @@ -0,0 +1,111 @@ +package renderers + +import ( + "strings" + + "github.com/ollama/ollama/api" +) + +const ( + lagunaBOS = "〈|EOS|〉" + lagunaThoughtOpen = "" + lagunaThoughtClose = "" +) + +type LagunaRenderer struct{} + +func (r *LagunaRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + var sb strings.Builder + sb.WriteString(lagunaBOS) + + thinkingEnabled := think == nil || think.Bool() + systemMessage := "" + firstMessageIsSystem := len(messages) > 0 && messages[0].Role == "system" + if firstMessageIsSystem { + systemMessage = strings.TrimRight(messages[0].Content, "\n") + } + + sb.WriteString("\n") + if thinkingEnabled { + sb.WriteString("You should use chain-of-thought reasoning. 
Put your reasoning inside tags before your response.") + } else { + sb.WriteString("You should respond directly without using chain-of-thought reasoning tags.") + } + if strings.TrimSpace(systemMessage) != "" { + sb.WriteByte('\n') + sb.WriteString(systemMessage) + } + if len(tools) > 0 { + sb.WriteString("\n\n### Tools\n\n") + sb.WriteString("You may call functions to assist with the user query.\n") + sb.WriteString("All available function signatures are listed below:\n") + sb.WriteString("\n") + for _, tool := range tools { + if b, err := marshalWithSpaces(tool); err == nil { + sb.Write(b) + sb.WriteByte('\n') + } + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, return a json object with function name and arguments within '' and '' tags:\n") + sb.WriteString("\n{\"name\": , \"arguments\": }\n") + } + sb.WriteString("\n\n") + + for i, message := range messages { + if i == 0 && firstMessageIsSystem { + continue + } + content := message.Content + switch message.Role { + case "user": + sb.WriteString("\n") + sb.WriteString(content) + sb.WriteString("\n\n") + case "assistant": + lastMessage := i == len(messages)-1 + prefill := lastMessage && (content != "" || message.Thinking != "" || len(message.ToolCalls) > 0) + sb.WriteString("\n") + if thinkingEnabled && message.Thinking != "" { + sb.WriteString(lagunaThoughtOpen) + sb.WriteString(message.Thinking) + sb.WriteString(lagunaThoughtClose) + sb.WriteByte('\n') + } + if strings.Trim(content, "\n") != "" { + sb.WriteString(strings.Trim(content, "\n")) + sb.WriteByte('\n') + } + for _, toolCall := range message.ToolCalls { + sb.WriteString("") + sb.WriteString(toolCall.Function.Name) + sb.WriteByte('\n') + for name, value := range toolCall.Function.Arguments.All() { + sb.WriteString("") + sb.WriteString(name) + sb.WriteString("\n") + sb.WriteString("") + sb.WriteString(formatToolCallArgument(value)) + sb.WriteString("\n") + } + sb.WriteString("\n") + } + if !prefill { + sb.WriteString("\n") + } + 
case "tool": + sb.WriteString("\n") + sb.WriteString(content) + sb.WriteString("\n\n") + case "system": + sb.WriteString("\n") + sb.WriteString(content) + sb.WriteString("\n\n") + } + } + + if len(messages) == 0 || messages[len(messages)-1].Role != "assistant" { + sb.WriteString("\n") + } + return sb.String(), nil +} diff --git a/model/renderers/laguna_test.go b/model/renderers/laguna_test.go new file mode 100644 index 000000000..e5cecfae7 --- /dev/null +++ b/model/renderers/laguna_test.go @@ -0,0 +1,339 @@ +package renderers + +import ( + "encoding/json" + "os" + "os/exec" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +const ( + lagunaDirectDirective = "You should respond directly without using chain-of-thought reasoning tags." + lagunaThinkDirective = "You should use chain-of-thought reasoning. Put your reasoning inside tags before your response." +) + +func TestLagunaRendererReferenceFlowCoverage(t *testing.T) { + weather := lagunaWeatherTool() + + tests := []struct { + name string + messages []api.Message + tools []api.Tool + think *api.ThinkValue + want string + }{ + { + name: "user_only_thinking_default_on", + messages: []api.Message{{Role: "user", Content: "Hello"}}, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\n\n" + + "\nHello\n\n" + + "\n", + }, + { + name: "user_only_thinking_enabled", + messages: []api.Message{{Role: "user", Content: "Hello"}}, + think: &api.ThinkValue{Value: true}, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\n\n" + + "\nHello\n\n" + + "\n", + }, + { + name: "user_only_thinking_disabled", + messages: []api.Message{{Role: "user", Content: "Hello"}}, + think: &api.ThinkValue{Value: false}, + want: "" + + "〈|EOS|〉\n" + + lagunaDirectDirective + + "\n\n" + + "\nHello\n\n" + + "\n", + }, + { + name: "first_system_is_header", + messages: []api.Message{ + {Role: "system", Content: "Stay concise.\n\n"}, + {Role: "user", Content: "Hi"}, + }, + want: "" + + "〈|EOS|〉\n" 
+ + lagunaThinkDirective + + "\nStay concise." + + "\n\n" + + "\nHi\n\n" + + "\n", + }, + { + name: "additional_system_message_renders_in_loop", + messages: []api.Message{ + {Role: "system", Content: "Primary."}, + {Role: "user", Content: "Hi"}, + {Role: "system", Content: "Secondary."}, + }, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\nPrimary." + + "\n\n" + + "\nHi\n\n" + + "\nSecondary.\n\n" + + "\n", + }, + { + name: "tools_in_header", + messages: []api.Message{ + {Role: "system", Content: "Stay concise."}, + {Role: "user", Content: "Weather?"}, + }, + tools: weather, + think: &api.ThinkValue{Value: true}, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\nStay concise." + + "\n\n### Tools\n\n" + + "You may call functions to assist with the user query.\n" + + "All available function signatures are listed below:\n" + + "\n" + + `{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "City"}}}}}` + "\n" + + "\n\n" + + "For each function call, return a json object with function name and arguments within '' and '' tags:\n" + + "\n{\"name\": , \"arguments\": }\n" + + "\n\n" + + "\nWeather?\n\n" + + "\n", + }, + { + name: "tools_default_thinking_on_when_unspecified", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + }, + tools: weather, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\n\n### Tools\n\n" + + "You may call functions to assist with the user query.\n" + + "All available function signatures are listed below:\n" + + "\n" + + `{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "City"}}}}}` + "\n" + + "\n\n" + + "For each function call, return a json object with function name and arguments within '' and '' 
tags:\n" + + "\n{\"name\": , \"arguments\": }\n" + + "\n\n" + + "\nWeather?\n\n" + + "\n", + }, + { + name: "assistant_history_with_thinking_content_tool_and_response", + messages: []api.Message{ + {Role: "user", Content: "Add these."}, + { + Role: "assistant", + Content: "\nCalling the tool.\n", + Thinking: "Need addition.", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "add", + Arguments: testArgsOrdered([]orderedArg{ + {Key: "a", Value: 2}, + {Key: "b", Value: 3}, + }), + }, + }}, + }, + {Role: "tool", Content: "5"}, + {Role: "user", Content: "Thanks"}, + }, + think: &api.ThinkValue{Value: true}, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\n\n" + + "\nAdd these.\n\n" + + "\n" + + "Need addition.\n" + + "Calling the tool.\n" + + "add\n" + + "a\n2\n" + + "b\n3\n" + + "\n" + + "\n" + + "\n5\n\n" + + "\nThanks\n\n" + + "\n", + }, + { + name: "final_assistant_prefill_is_continued", + messages: []api.Message{ + {Role: "user", Content: "Complete this"}, + {Role: "assistant", Content: "Partial"}, + }, + want: "" + + "〈|EOS|〉\n" + + lagunaThinkDirective + + "\n\n" + + "\nComplete this\n\n" + + "\nPartial\n", + }, + } + + renderer := &LagunaRenderer{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := renderer.Render(tt.messages, tt.tools, tt.think) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("renderer output mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLagunaRendererMatchesLocalJinjaControlFlow(t *testing.T) { + if os.Getenv("VERIFY_LAGUNA_JINJA2") == "" { + t.Skip("set VERIFY_LAGUNA_JINJA2=1 to compare against the local Laguna chat_template.jinja") + } + python := "/Users/daniel/.codex/worktrees/7038/ollama/.venv/bin/python3" + if _, err := os.Stat(python); err != nil { + t.Fatalf("VERIFY_LAGUNA_JINJA2 requires %s with jinja2 installed", python) + } + + tests := []struct { + name string + messages []api.Message + think 
*api.ThinkValue + }{ + { + name: "user_only", + messages: []api.Message{{Role: "user", Content: "Hello"}}, + }, + { + name: "system_user", + messages: []api.Message{ + {Role: "system", Content: "Stay concise.\n"}, + {Role: "user", Content: "Hello"}, + }, + }, + { + name: "additional_system_and_tool_response", + messages: []api.Message{ + {Role: "system", Content: "Primary."}, + {Role: "user", Content: "Weather?"}, + {Role: "assistant", Content: "Calling."}, + {Role: "tool", Content: "Sunny"}, + {Role: "system", Content: "Secondary."}, + }, + }, + { + name: "thinking_enabled", + messages: []api.Message{{Role: "user", Content: "Think briefly."}}, + think: &api.ThinkValue{Value: true}, + }, + { + name: "thinking_disabled", + messages: []api.Message{{Role: "user", Content: "Answer directly."}}, + think: &api.ThinkValue{Value: false}, + }, + } + + renderer := &LagunaRenderer{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := renderer.Render(tt.messages, nil, tt.think) + if err != nil { + t.Fatal(err) + } + for _, modelDir := range []string{ + "/Users/daniel/Models/poolside/laguna-xs-23-04-2026", + } { + want := renderLagunaChatTemplate(t, python, modelDir, tt.messages, tt.think) + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("%s mismatch (-chat_template +renderer):\n%s", modelDir, diff) + } + } + }) + } +} + +func renderLagunaChatTemplate(t *testing.T, python, modelDir string, messages []api.Message, think *api.ThinkValue) string { + t.Helper() + + type templateMessage struct { + Role string `json:"role"` + Content string `json:"content"` + } + templateMessages := make([]templateMessage, 0, len(messages)) + for _, msg := range messages { + templateMessages = append(templateMessages, templateMessage{ + Role: msg.Role, + Content: msg.Content, + }) + } + messagesJSON, err := json.Marshal(templateMessages) + if err != nil { + t.Fatalf("failed to marshal messages: %v", err) + } + + enableThinking := "True" + if think != nil && 
!think.Bool() { + enableThinking = "False" + } + + script := ` +import json +import sys +from transformers import AutoTokenizer + +model_dir = sys.argv[1] +messages = json.loads(sys.argv[2]) +enable_thinking = sys.argv[3] == "True" +tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) +print(tok.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, +), end="") +` + cmd := exec.Command(python, "-c", script, modelDir, string(messagesJSON), enableThinking) + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + t.Fatalf("chat_template render failed: %v\nstderr: %s", err, stderr.String()) + } + return stdout.String() +} + +func lagunaWeatherTool() []api.Tool { + return []api.Tool{{ + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: testPropsOrdered([]orderedProp{{ + Key: "location", + Value: api.ToolProperty{ + Type: api.PropertyType{"string"}, + Description: "City", + }, + }}), + }, + }, + }} +} diff --git a/model/renderers/nemotron3nano.go b/model/renderers/nemotron3nano.go index df847b48f..8c3cb58aa 100644 --- a/model/renderers/nemotron3nano.go +++ b/model/renderers/nemotron3nano.go @@ -3,6 +3,9 @@ package renderers import ( "encoding/json" "fmt" + "reflect" + "sort" + "strconv" "strings" "github.com/ollama/ollama/api" @@ -12,15 +15,15 @@ type Nemotron3NanoRenderer struct{} func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { var sb strings.Builder + imageOffset := 0 - // thinking is enabled if user requests it - enableThinking := thinkValue != nil && thinkValue.Bool() + enableThinking := r.resolveThinking(messages, thinkValue) // Extract system message if present var systemMessage string var 
loopMessages []api.Message if len(messages) > 0 && messages[0].Role == "system" { - systemMessage = messages[0].Content + systemMessage = r.sanitizeSystemMessage(messages[0].Content) loopMessages = messages[1:] } else { loopMessages = messages @@ -34,6 +37,7 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, } } + sb.WriteString("\n\n\n") sb.WriteString("<|im_start|>system\n") if systemMessage != "" { sb.WriteString(systemMessage) @@ -45,28 +49,30 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, } sb.WriteString(r.renderTools(tools)) } - sb.WriteString("<|im_end|>\n") + sb.WriteString("<|im_end|>\n\n") for i, message := range loopMessages { switch message.Role { case "assistant": - // Build content with thinking tags content := r.buildContent(message) shouldTruncate := i < lastUserIdx if len(message.ToolCalls) > 0 { sb.WriteString("<|im_start|>assistant\n") - sb.WriteString(r.formatContent(content, shouldTruncate, true)) + sb.WriteString(r.formatToolCallContent(content, shouldTruncate)) r.writeToolCalls(&sb, message.ToolCalls) sb.WriteString("<|im_end|>\n") } else { - formatted := r.formatContent(content, shouldTruncate, false) - sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n") + formatted := r.formatAssistantContent(content, shouldTruncate) + sb.WriteString("<|im_start|>assistant\n") + sb.WriteString(formatted) + sb.WriteString("<|im_end|>\n") } case "user", "system": sb.WriteString("<|im_start|>" + message.Role + "\n") - sb.WriteString(message.Content) + sb.WriteString(r.renderMessageContent(message, imageOffset)) + imageOffset += len(message.Images) sb.WriteString("<|im_end|>\n") case "tool": @@ -90,6 +96,8 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, } } + sb.WriteString("\n") + // Add generation prompt if enableThinking { sb.WriteString("<|im_start|>assistant\n\n") @@ -119,7 +127,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools 
// formatAssistantContent prepares the body of an assistant message that has
// no tool calls.
//
// When truncate is false the content is returned with surrounding whitespace
// removed. When truncate is true and the content contains both marker
// strings, only the text after the last split marker is kept, prefixed with a
// marker literal, and the result is trimmed in the same way.
//
// NOTE(review): the marker literals read as empty strings in this copy of the
// patch; confirm the intended markers against the repository before relying
// on this description.
func (r *Nemotron3NanoRenderer) formatAssistantContent(content string, truncate bool) string {
	if !truncate {
		return strings.TrimSpace(content)
	}

	c := content
	// Only rewrite the content when both markers are present.
	if strings.Contains(c, "") && strings.Contains(c, "") {
		parts := strings.Split(c, "")
		// Keep just the final segment produced by the split.
		c = "" + parts[len(parts)-1]
	}
	return strings.TrimSpace(c)
}
*Nemotron3NanoRenderer) formatToolCallContent(content string, truncate bool) string { + if strings.TrimSpace(content) == "" { return "" } if !truncate { - if addNewline { - return strings.TrimSpace(content) + "\n" - } - return strings.TrimSpace(content) + return strings.TrimSpace(content) + "\n" } - // Truncate thinking - keep only content after c := content if strings.Contains(c, "") { parts := strings.Split(c, "") @@ -190,13 +209,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add } c = "" + strings.TrimSpace(c) - if addNewline && len(c) > len("") { - return c + "\n" - } - if c == "" { - return c - } - return strings.TrimSpace(c) + return strings.TrimSpace(c) + "\n" } func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) { @@ -212,9 +225,225 @@ func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls [] func (r *Nemotron3NanoRenderer) formatArgValue(value any) string { switch v := value.(type) { case map[string]any, []any: - jsonBytes, _ := json.Marshal(v) - return string(jsonBytes) + return r.pythonJSON(v) default: return fmt.Sprintf("%v", v) } } + +func (r *Nemotron3NanoRenderer) renderMessageContent(message api.Message, imageOffset int) string { + content := nemotron3NanoRenderContent(message.Content) + if len(message.Images) == 0 { + return content + } + + if strings.Contains(content, "[img-") { + return content + } + + if strings.Contains(content, "[img]") { + for i := range message.Images { + content = strings.Replace(content, "[img]", fmt.Sprintf("[img-%d]", imageOffset+i), 1) + } + return content + } + + var sb strings.Builder + for i := range message.Images { + sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i)) + } + sb.WriteString(content) + return sb.String() +} + +func nemotron3NanoRenderContent(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var sb strings.Builder + for _, item := range v { + obj, ok := 
item.(map[string]any) + if !ok { + bts, _ := json.Marshal(item) + sb.Write(bts) + continue + } + + switch obj["type"] { + case "image": + sb.WriteString("") + case "text": + if text, ok := obj["text"].(string); ok { + sb.WriteString(text) + } + default: + bts, _ := json.Marshal(item) + sb.Write(bts) + } + } + return sb.String() + default: + bts, _ := json.Marshal(v) + return string(bts) + } +} + +func (r *Nemotron3NanoRenderer) resolveThinking(messages []api.Message, thinkValue *api.ThinkValue) bool { + enableThinking := thinkValue == nil || thinkValue.Bool() + for _, message := range messages { + if message.Role != "user" && message.Role != "system" { + continue + } + content := message.Content + if strings.Contains(strings.ReplaceAll(content, "", ""), "/think") { + enableThinking = true + } else if strings.Contains(content, "/no_think") { + enableThinking = false + } + } + return enableThinking +} + +func (r *Nemotron3NanoRenderer) sanitizeSystemMessage(content string) string { + system := nemotron3NanoRenderContent(content) + system = strings.ReplaceAll(system, "", "<_end_think>") + system = strings.ReplaceAll(system, "/think", "") + system = strings.ReplaceAll(system, "/no_think", "") + system = strings.ReplaceAll(system, "<_end_think>", "") + return system +} + +func (r *Nemotron3NanoRenderer) formatPropertyType(propertyType api.PropertyType) string { + if len(propertyType) == 1 { + return propertyType[0] + } + quoted := make([]string, 0, len(propertyType)) + for _, v := range propertyType { + quoted = append(quoted, "'"+v+"'") + } + return "[" + strings.Join(quoted, ", ") + "]" +} + +func (r *Nemotron3NanoRenderer) renderToolPropertyExtraKeys(sb *strings.Builder, prop api.ToolProperty) { + if len(prop.AnyOf) > 0 { + sb.WriteString("\n" + r.pythonJSON(prop.AnyOf) + "") + } + if prop.Items != nil { + sb.WriteString("\n" + r.pythonJSON(prop.Items) + "") + } + if prop.Properties != nil { + sb.WriteString("\n" + r.pythonJSON(prop.Properties) + "") + } + if 
len(prop.Required) > 0 { + sb.WriteString("\n" + r.pythonJSON(prop.Required) + "") + } +} + +func (r *Nemotron3NanoRenderer) renderToolParameterExtraKeys(sb *strings.Builder, params api.ToolFunctionParameters) { + if params.Defs != nil { + sb.WriteString("\n<$defs>" + r.pythonJSON(params.Defs) + "") + } + if params.Items != nil { + sb.WriteString("\n" + r.pythonJSON(params.Items) + "") + } +} + +func (r *Nemotron3NanoRenderer) pythonJSON(v any) string { + switch value := v.(type) { + case nil: + return "null" + case string: + return strconv.Quote(value) + case bool: + if value { + return "true" + } + return "false" + case int, int8, int16, int32, int64: + return fmt.Sprintf("%d", reflect.ValueOf(value).Int()) + case uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", reflect.ValueOf(value).Uint()) + case float32, float64: + b, _ := json.Marshal(value) + return string(b) + case api.PropertyType: + return r.pythonJSON([]string(value)) + case []string: + parts := make([]string, 0, len(value)) + for _, item := range value { + parts = append(parts, r.pythonJSON(item)) + } + return "[" + strings.Join(parts, ", ") + "]" + case []any: + parts := make([]string, 0, len(value)) + for _, item := range value { + parts = append(parts, r.pythonJSON(item)) + } + return "[" + strings.Join(parts, ", ") + "]" + case []api.ToolProperty: + parts := make([]string, 0, len(value)) + for _, item := range value { + parts = append(parts, r.pythonJSON(item)) + } + return "[" + strings.Join(parts, ", ") + "]" + case map[string]any: + keys := make([]string, 0, len(value)) + for key := range value { + keys = append(keys, key) + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, key := range keys { + parts = append(parts, strconv.Quote(key)+": "+r.pythonJSON(value[key])) + } + return "{" + strings.Join(parts, ", ") + "}" + case *api.ToolPropertiesMap: + if value == nil { + return "null" + } + parts := make([]string, 0, value.Len()) + for key, prop := range 
value.All() { + parts = append(parts, strconv.Quote(key)+": "+r.pythonJSON(prop)) + } + return "{" + strings.Join(parts, ", ") + "}" + case api.ToolProperty: + parts := make([]string, 0, 6) + if len(value.AnyOf) > 0 { + parts = append(parts, `"anyOf": `+r.pythonJSON(value.AnyOf)) + } + if len(value.Type) > 0 { + if len(value.Type) == 1 { + parts = append(parts, `"type": `+r.pythonJSON(value.Type[0])) + } else { + parts = append(parts, `"type": `+r.pythonJSON([]string(value.Type))) + } + } + if value.Items != nil { + parts = append(parts, `"items": `+r.pythonJSON(value.Items)) + } + if value.Description != "" { + parts = append(parts, `"description": `+r.pythonJSON(value.Description)) + } + if len(value.Enum) > 0 { + parts = append(parts, `"enum": `+r.pythonJSON(value.Enum)) + } + if value.Properties != nil { + parts = append(parts, `"properties": `+r.pythonJSON(value.Properties)) + } + if len(value.Required) > 0 { + parts = append(parts, `"required": `+r.pythonJSON(value.Required)) + } + return "{" + strings.Join(parts, ", ") + "}" + default: + b, err := json.Marshal(value) + if err != nil { + return "null" + } + var generic any + if err := json.Unmarshal(b, &generic); err != nil { + return string(b) + } + return r.pythonJSON(generic) + } +} diff --git a/model/renderers/nemotron3nano_reference_test.go b/model/renderers/nemotron3nano_reference_test.go new file mode 100644 index 000000000..8a39f0c4d --- /dev/null +++ b/model/renderers/nemotron3nano_reference_test.go @@ -0,0 +1,614 @@ +package renderers + +import ( + "encoding/json" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +const nemotron3NanoTemplate = "testdata/nemotron3nano_chat_template.jinja2" + +func TestNemotron3NanoRendererMatchesReference(t *testing.T) { + toolText := `<|im_start|>system +# Tools + +You have access to the following functions: + + + +search_docs +Search docs + + +query +string +Search 
query +["api", "cli"] + + +mode +['string', 'null'] +Mode +[{"type": "string"}, {"type": "number"}] + + +payload +object +Payload +{"enabled": {"type": "boolean"}} +["enabled"] + + +tags +array +Tags +{"type": "string"} + +<$defs>{"shared": {"type": "string"}} +["query"] + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +` + toolTextWithSystem := `<|im_start|>system +Follow policy. + +# Tools + +You have access to the following functions: + + + +search_docs +Search docs + + +query +string +Search query +["api", "cli"] + + +mode +['string', 'null'] +Mode +[{"type": "string"}, {"type": "number"}] + + +payload +object +Payload +{"enabled": {"type": "boolean"}} +["enabled"] + + +tags +array +Tags +{"type": "string"} + +<$defs>{"shared": {"type": "string"}} +["query"] + + + + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +<|im_end|> +` + + tests 
:= []struct { + name string + messages []api.Message + tools []api.Tool + think *api.ThinkValue + expected string + }{ + { + name: "no system default thinking on", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "no system explicit thinking off", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + think: thinkFalse(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n", + }, + { + name: "literal endthink does not enable thinking", + messages: []api.Message{ + {Role: "user", Content: "literal only"}, + }, + think: thinkFalse(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nliteral only<|im_end|>\n\n<|im_start|>assistant\n", + }, + { + name: "user no think toggle", + messages: []api.Message{ + {Role: "user", Content: "Hello /no_think"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello /no_think<|im_end|>\n\n<|im_start|>assistant\n", + }, + { + name: "system think toggle overrides false", + messages: []api.Message{ + {Role: "system", Content: "Policy /think"}, + {Role: "user", Content: "Hello"}, + }, + think: thinkFalse(), + expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "later toggle wins", + messages: []api.Message{ + {Role: "system", Content: "Policy /no_think"}, + {Role: "user", Content: "Actually /think"}, + }, + think: thinkFalse(), + expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nActually /think<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "system sanitizes toggles but preserves closing tag", + messages: []api.Message{ + {Role: "system", Content: "A /think B /no_think C "}, + {Role: "user", Content: "Hello"}, + }, + 
think: thinkFalse(), + expected: "\n\n\n<|im_start|>system\nA B C <|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant plain content adds empty think block", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello there"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello there<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant reasoning content", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Answer", Thinking: "Need to think"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n\nNeed to think\n\nAnswer<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant preserves existing think tags", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "keptAnswer"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nkeptAnswer<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "tools without system", + messages: []api.Message{ + {Role: "user", Content: "Use a tool"}, + }, + tools: nemotron3NanoReferenceTools(), + think: thinkTrue(), + expected: "\n\n\n" + toolText + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "system with tools", + messages: []api.Message{ + {Role: "system", Content: "Follow policy."}, + {Role: "user", Content: "Use a tool"}, + }, + tools: nemotron3NanoReferenceTools(), + think: thinkTrue(), + expected: "\n\n\n" + toolTextWithSystem + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant tool call with content", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: 
"assistant", + Content: "Checking now.", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"city": "Paris"}), + }, + }}, + }, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\nChecking now.\n\n\n\nParis\n\n\n\n<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant tool call with structured arguments", + messages: []api.Message{ + {Role: "user", Content: "Create data"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "create", + Arguments: testArgsOrdered([]orderedArg{ + {Key: "payload", Value: map[string]any{"count": 42, "nested": map[string]any{"value": "ok"}}}, + {Key: "tags", Value: []any{"a", "b"}}, + }), + }, + }}, + }, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nCreate data<|im_end|>\n<|im_start|>assistant\n\n\n\n\n{\"count\": 42, \"nested\": {\"value\": \"ok\"}}\n\n\n[\"a\", \"b\"]\n\n\n\n<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant tool call truncated with reasoning", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "Checking now.", + Thinking: "Need weather", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"city": "Paris"}), + }, + }}, + }, + {Role: "user", Content: "And tomorrow?"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\nChecking now.\n\n\n\nParis\n\n\n\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant tool call truncated open think only", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "draft", + ToolCalls: 
[]api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"city": "Paris"}), + }, + }}, + }, + {Role: "user", Content: "And tomorrow?"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nParis\n\n\n\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant tool call empty content", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: testArgs(map[string]any{"city": "Paris"}), + }, + }}, + }, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nParis\n\n\n\n<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant truncated with think pair", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "hiddenVisible"}, + {Role: "user", Content: "Next"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nVisible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant truncated reasoning content", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Thinking: "hidden", Content: "Visible"}, + {Role: "user", Content: "Next"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n\nVisible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant truncated plain content", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Visible"}, + {Role: "user", Content: 
"Next"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nVisible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "assistant truncated empty content", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: ""}, + {Role: "user", Content: "Next"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "consecutive tool messages grouped", + messages: []api.Message{ + {Role: "user", Content: "Do work"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "step", + Arguments: testArgs(map[string]any{"value": 1}), + }, + }}, + }, + {Role: "tool", Content: "one"}, + {Role: "tool", Content: "two"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n\n\n\n\n1\n\n\n\n<|im_end|>\n<|im_start|>user\n\none\n\n\ntwo\n\n<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + { + name: "fallback role", + messages: []api.Message{ + {Role: "developer", Content: "Custom role content"}, + }, + think: thinkTrue(), + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>developer\nCustom role content<|im_end|>\n\n<|im_start|>assistant\n\n", + }, + } + + verifyJinja2 := os.Getenv("VERIFY_JINJA2") != "" + if verifyJinja2 { + if _, err := os.Stat(filepath.Join(nemotron3NanoRepoRoot(t), ".venv", "bin", "python3")); err != nil { + t.Fatal("VERIFY_JINJA2=1 requires .venv/bin/python3") + } + } + + renderer := &Nemotron3NanoRenderer{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := renderer.Render(tt.messages, tt.tools, tt.think) + if err != nil { + t.Fatalf("Render() error = %v", err) + } + + if 
diff := cmp.Diff(tt.expected, got); diff != "" { + t.Fatalf("renderer mismatch (-want +got):\n%s", diff) + } + + if verifyJinja2 { + jinja2Output := renderNemotron3NanoWithJinja2(t, tt.messages, tt.tools, tt.think) + if diff := cmp.Diff(tt.expected, jinja2Output); diff != "" { + t.Fatalf("reference template mismatch (-want +got):\n%s", diff) + } + } + }) + } +} + +func nemotron3NanoReferenceTools() []api.Tool { + return []api.Tool{{ + Type: "function", + Function: api.ToolFunction{ + Name: "search_docs", + Description: "Search docs", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Defs: map[string]any{"shared": map[string]any{"type": "string"}}, + Required: []string{"query"}, + Properties: testPropsOrdered([]orderedProp{ + { + Key: "query", + Value: api.ToolProperty{ + Type: api.PropertyType{"string"}, + Description: "Search query", + Enum: []any{"api", "cli"}, + }, + }, + { + Key: "mode", + Value: api.ToolProperty{ + Type: api.PropertyType{"string", "null"}, + Description: "Mode", + AnyOf: []api.ToolProperty{ + {Type: api.PropertyType{"string"}}, + {Type: api.PropertyType{"number"}}, + }, + }, + }, + { + Key: "payload", + Value: api.ToolProperty{ + Type: api.PropertyType{"object"}, + Description: "Payload", + Properties: testPropsOrdered([]orderedProp{{Key: "enabled", Value: api.ToolProperty{Type: api.PropertyType{"boolean"}}}}), + Required: []string{"enabled"}, + }, + }, + { + Key: "tags", + Value: api.ToolProperty{ + Type: api.PropertyType{"array"}, + Description: "Tags", + Items: map[string]any{"type": "string"}, + }, + }, + }), + }, + }, + }} +} + +func renderNemotron3NanoWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string { + t.Helper() + + type jinja2ToolCall struct { + ID string `json:"id,omitempty"` + Function struct { + Name string `json:"name"` + Arguments any `json:"arguments"` + } `json:"function"` + } + type jinja2Message struct { + Role string `json:"role"` + Content string `json:"content"` + 
ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []jinja2ToolCall `json:"tool_calls,omitempty"` + Name string `json:"name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + } + + var jMsgs []jinja2Message + for _, m := range messages { + jm := jinja2Message{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.Thinking, + Name: m.ToolName, + ToolCallID: m.ToolCallID, + } + for _, tc := range m.ToolCalls { + jtc := jinja2ToolCall{ID: tc.ID} + jtc.Function.Name = tc.Function.Name + var args map[string]any + raw, _ := tc.Function.Arguments.MarshalJSON() + if err := json.Unmarshal(raw, &args); err != nil { + t.Fatalf("failed to unmarshal tool args: %v", err) + } + jtc.Function.Arguments = args + jm.ToolCalls = append(jm.ToolCalls, jtc) + } + jMsgs = append(jMsgs, jm) + } + + msgsJSON, err := json.Marshal(jMsgs) + if err != nil { + t.Fatalf("failed to marshal messages: %v", err) + } + + toolsJSON := "None" + if len(tools) > 0 { + b, err := json.Marshal(tools) + if err != nil { + t.Fatalf("failed to marshal tools: %v", err) + } + toolsJSON = string(b) + } + + thinking := "unset" + if think != nil { + if think.Bool() { + thinking = "true" + } else { + thinking = "false" + } + } + + repoRoot := nemotron3NanoRepoRoot(t) + templatePath := filepath.Join(repoRoot, "model", "renderers", nemotron3NanoTemplate) + pythonPath := filepath.Join(repoRoot, ".venv", "bin", "python3") + script := ` +import json +import sys +from pathlib import Path +from transformers.utils.chat_template_utils import _compile_jinja_template + +template_path, messages_json, tools_json, thinking = sys.argv[1:5] +tmpl = _compile_jinja_template(Path(template_path).read_text()) +kwargs = { + "messages": json.loads(messages_json), + "add_generation_prompt": True, +} +if tools_json != "None": + kwargs["tools"] = json.loads(tools_json) +if thinking == "true": + kwargs["enable_thinking"] = True +elif thinking == "false": + kwargs["enable_thinking"] = False 
+print(tmpl.render(**kwargs), end="") +` + + cmd := exec.Command(pythonPath, "-c", script, templatePath, string(msgsJSON), toolsJSON, thinking) + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + t.Fatalf("python render failed: %v\nstderr: %s", err, stderr.String()) + } + return stdout.String() +} + +func nemotron3NanoRepoRoot(t *testing.T) string { + t.Helper() + _, filename, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("failed to locate test file") + } + return filepath.Dir(filepath.Dir(filepath.Dir(filename))) +} diff --git a/model/renderers/nemotron3nano_test.go b/model/renderers/nemotron3nano_test.go index db8329fa7..1c55ab3e7 100644 --- a/model/renderers/nemotron3nano_test.go +++ b/model/renderers/nemotron3nano_test.go @@ -8,561 +8,46 @@ import ( "github.com/ollama/ollama/api" ) -func TestNemotron3NanoRenderer(t *testing.T) { +func TestNemotron3NanoRenderer_Images(t *testing.T) { tests := []struct { - name string - msgs []api.Message - tools []api.Tool - thinkValue *api.ThinkValue - expected string + name string + msgs []api.Message + expected string }{ { - name: "basic user message - thinking mode", + name: "single image inserts placeholder", msgs: []api.Message{ - {Role: "user", Content: "Hello!"}, + {Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}}, }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHello!<|im_end|>\n" + - "<|im_start|>assistant\n\n", + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n\n", }, { - name: "basic user message - no thinking", + name: "generic image placeholder is rewritten", msgs: []api.Message{ - {Role: "user", Content: "Hello!"}, + {Role: "user", Content: "[img]Describe this image.", Images: []api.ImageData{api.ImageData("img1")}}, }, - thinkValue: nil, - 
expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHello!<|im_end|>\n" + - "<|im_start|>assistant\n", + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n\n", }, { - name: "with system message", + name: "image offsets increment across messages", msgs: []api.Message{ - {Role: "system", Content: "You are a helpful assistant."}, - {Role: "user", Content: "Hello!"}, + {Role: "user", Content: "Describe the first image.", Images: []api.ImageData{api.ImageData("img1")}}, + {Role: "assistant", Content: "It shows something."}, + {Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}}, }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + - "<|im_start|>user\nHello!<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "multi-turn conversation", - msgs: []api.Message{ - {Role: "user", Content: "Hi"}, - {Role: "assistant", Content: "Hello! How can I help?"}, - {Role: "user", Content: "Tell me a joke"}, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHi<|im_end|>\n" + - "<|im_start|>assistant\nHello! 
How can I help?<|im_end|>\n" + - "<|im_start|>user\nTell me a joke<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "with tools", - msgs: []api.Message{ - {Role: "user", Content: "What's the weather in Paris?"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the current weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Required: []string{"city"}, - Properties: testPropsMap(map[string]api.ToolProperty{ - "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_weather\n" + - "Get the current weather\n" + - "\n" + - "\ncity\nstring\nThe city name\n\n" + - "[\"city\"]\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "tool call with response", - msgs: []api.Message{ - {Role: "user", Content: "What's the weather in Paris?"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: testArgs(map[string]any{"city": "Paris"}), - }, - }, - }, - 
}, - {Role: "tool", Content: "Sunny, 72F"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the current weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Required: []string{"city"}, - Properties: testPropsMap(map[string]api.ToolProperty{ - "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_weather\n" + - "Get the current weather\n" + - "\n" + - "\ncity\nstring\nThe city name\n\n" + - "[\"city\"]\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + - "<|im_start|>assistant\n\n" + - "\n\n\nParis\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nSunny, 72F\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "assistant with content and tool call", - msgs: []api.Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "Let me check that for you.", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: testArgs(map[string]any{"city": "Paris"}), - }, - }, - }, - }, - {Role: 
"tool", Content: "Sunny"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{ - "city": {Type: api.PropertyType{"string"}}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_weather\n" + - "\n" + - "\ncity\nstring\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nWhat's the weather?<|im_end|>\n" + - "<|im_start|>assistant\nLet me check that for you.\n" + - "\n\n\nParis\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nSunny\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "thinking in history is truncated", - msgs: []api.Message{ - {Role: "user", Content: "Hi"}, - {Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."}, - {Role: "user", Content: "How are you?"}, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHi<|im_end|>\n" + - "<|im_start|>assistant\nHello!<|im_end|>\n" + - "<|im_start|>user\nHow are you?<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "parallel tool calls", - 
msgs: []api.Message{ - {Role: "user", Content: "Weather in Paris and London?"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: testArgs(map[string]any{"city": "Paris"}), - }, - }, - { - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: testArgs(map[string]any{"city": "London"}), - }, - }, - }, - }, - {Role: "tool", Content: "Sunny"}, - {Role: "tool", Content: "Rainy"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{ - "city": {Type: api.PropertyType{"string"}}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_weather\n" + - "\n" + - "\ncity\nstring\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" + - "<|im_start|>assistant\n\n" + - "\n\n\nParis\n\n\n\n" + - "\n\n\nLondon\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nSunny\n\n\nRainy\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "thinking disabled when user doesn't request it", - msgs: 
[]api.Message{ - {Role: "user", Content: "Hello!"}, - }, - thinkValue: nil, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHello!<|im_end|>\n" + - "<|im_start|>assistant\n", - }, - { - name: "complex message history with thinking, tools, tool calls, tool results and content", - msgs: []api.Message{ - {Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"}, - {Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}}, - {Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}}, - }}, - {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"}, - {Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"}, - {Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}}, - }}, - {Role: "tool", Content: "4", ToolCallID: "call3"}, - {Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! 
I have all the information needed to provide a complete answer."}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{ - "city": {Type: api.PropertyType{"string"}}, - }), - }, - }, - }, - { - Type: "function", - Function: api.ToolFunction{ - Name: "calculate", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{ - "expression": {Type: api.PropertyType{"string"}}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_weather\n" + - "\n" + - "\ncity\nstring\n\n" + - "\n\n" + - "\ncalculate\n" + - "\n" + - "\nexpression\nstring\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" + - "<|im_start|>assistant\n" + - "\nI need to check the weather for both cities and calculate 2+2. 
Let me start with the weather calls.\n\n" + - "\n\n\nParis\n\n\n\n" + - "\n\n\nLondon\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nSunny, 22°C\n\n\nRainy, 15°C\n\n<|im_end|>\n" + - "<|im_start|>assistant\n" + - "\nNow I have the weather data. Let me calculate 2+2.\n\n" + - "\n\n\n2+2\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\n4\n\n<|im_end|>\n" + - "<|im_start|>assistant\n" + - "\nPerfect! I have all the information needed to provide a complete answer.\n\n" + - "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "empty messages list", - msgs: []api.Message{}, - thinkValue: nil, - expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n", - }, - { - name: "tool result with JSON content", - msgs: []api.Message{ - {Role: "user", Content: "Get user info"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}}, - }, - }, - {Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_user", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\nget_user\n\n" + - "\nid\nstring\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + 
- "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nGet user info<|im_end|>\n" + - "<|im_start|>assistant\n\n" + - "\n\n\n123\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "assistant message with only thinking no content", - msgs: []api.Message{ - {Role: "user", Content: "Think about this"}, - {Role: "assistant", Thinking: "Deep thoughts here...", Content: ""}, - {Role: "user", Content: "What did you think?"}, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nThink about this<|im_end|>\n" + - "<|im_start|>assistant\n<|im_end|>\n" + - "<|im_start|>user\nWhat did you think?<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "tool call with complex nested argument", - msgs: []api.Message{ - {Role: "user", Content: "Create data"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{ - Name: "create", - Arguments: testArgs(map[string]any{ - "data": map[string]any{"nested": "value", "count": 42}, - }), - }}, - }, - }, - {Role: "tool", Content: "Created"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "create", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\ncreate\n\n" + - "\ndata\nobject\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY 
reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nCreate data<|im_end|>\n" + - "<|im_start|>assistant\n\n" + - "\n\n\n{\"count\":42,\"nested\":\"value\"}\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nCreated\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "content explaining the format itself", - msgs: []api.Message{ - {Role: "user", Content: "How do I format a tool call?"}, - {Role: "assistant", Content: "To call a tool, use tags with inside."}, - {Role: "user", Content: "Thanks!"}, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n<|im_end|>\n" + - "<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" + - "<|im_start|>assistant\nTo call a tool, use tags with inside.<|im_end|>\n" + - "<|im_start|>user\nThanks!<|im_end|>\n" + - "<|im_start|>assistant\n\n", - }, - { - name: "unicode in content and tool args", - msgs: []api.Message{ - {Role: "user", Content: "Translate 你好"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - {Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}}, - }, - }, - {Role: "tool", Content: "Hello"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "translate", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: testPropsMap(map[string]api.ToolProperty{ - "text": {Type: 
api.PropertyType{"string"}}, - }), - }, - }, - }, - }, - thinkValue: &api.ThinkValue{Value: true}, - expected: "<|im_start|>system\n" + - "# Tools\n\nYou have access to the following functions:\n\n\n" + - "\ntranslate\n\n" + - "\ntext\nstring\n\n" + - "\n\n\n\n" + - "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + - "\n\n\nvalue_1\n\n" + - "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + - "\n\n\n\n\nReminder:\n" + - "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + - "- Required parameters MUST be specified\n" + - "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + - "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + - "<|im_end|>\n" + - "<|im_start|>user\nTranslate 你好<|im_end|>\n" + - "<|im_start|>assistant\n\n" + - "\n\n\n你好\n\n\n\n<|im_end|>\n" + - "<|im_start|>user\n\nHello\n\n<|im_end|>\n" + - "<|im_start|>assistant\n\n", + expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe the first image.<|im_end|>\n<|im_start|>assistant\nIt shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2]Compare these.<|im_end|>\n\n<|im_start|>assistant\n\n", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { renderer := &Nemotron3NanoRenderer{} - rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue) + rendered, err := renderer.Render(tt.msgs, nil, nil) if err != nil { t.Fatal(err) } - if diff := cmp.Diff(rendered, tt.expected); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + if diff := cmp.Diff(tt.expected, rendered); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) } }) } diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 84cc78f8d..95ca1a178 100644 --- 
a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -95,6 +95,8 @@ func rendererForName(name string) Renderer { return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags} case "lfm2-thinking": return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags} + case "laguna": + return &LagunaRenderer{} default: return nil } diff --git a/model/renderers/testdata/nemotron3nano_chat_template.jinja2 b/model/renderers/testdata/nemotron3nano_chat_template.jinja2 new file mode 100644 index 000000000..0f45006a8 --- /dev/null +++ b/model/renderers/testdata/nemotron3nano_chat_template.jinja2 @@ -0,0 +1,222 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} +{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %} +{%- set reasoning_budget = reasoning_budget if reasoning_budget is defined else None %} +{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %} +{%- set response_format = response_format if response_format is defined else None %} + +{# Scan messages for VLM thinking toggles to override enable_thinking #} +{%- set toggle = namespace(enable=enable_thinking) %} +{%- for m in messages %} + {%- if m['role'] == 'user' or m['role'] == 'system' -%} + {%- if m['content'] is string -%} + {%- set c = m['content'] %} + {%- if '/think' in c.replace('', '') -%} + {%- set toggle.enable = true -%} + {%- elif '/no_think' in c -%} + {%- set toggle.enable = false -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} +{%- endfor %} + +{# 
Prepare message iteration similar to LM template #} +{%- set ns = namespace(last_user_idx = -1) %} +{%- set loop_messages = messages %} +{%- for m in loop_messages %} + {%- if m["role"] == "user" %} + {%- set ns.last_user_idx = loop.index0 %} + {%- endif %} +{%- endfor %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} +{# Recompute last_user_idx relative to loop_messages after handling system #} +{%- set ns = namespace(last_user_idx = -1) %} +{%- for m in loop_messages %} + {%- if m["role"] == "user" %} + {%- set ns.last_user_idx = loop.index0 %} + {%- endif %} +{%- endfor %} + +{# System preamble with LM formatting, sanitize thinking toggles #} +{%- if system_message is defined %} + {%- set sys_content = system_message | string %} + {%- set sys_content = sys_content.replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '') %} + {{- "<|im_start|>system\n" + sys_content }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\n" }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {%- if system_message is defined and system_message | length > 0 %} + {{- "\n\n" }} + {%- endif %} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, 
param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- if param_fields.enum is defined %} + {{- '\n' ~ (param_fields.enum | tojson | safe) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description', 'enum'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties', 'required'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {%- if tool.parameters is defined and tool.parameters.required is defined %} + {{- '\n' ~ (tool.parameters.required | tojson | safe) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- endif %} + +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{# Iterate conversation #} +{%- for message in loop_messages %} + {%- if message.role == "assistant" %} + {# 
Use LM assistant handling #} + {%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %} + {%- set content = "\n" ~ message.reasoning_content ~ "\n\n" ~ (message.content | default('', true)) %} + {%- else %} + {%- set content = message.content | default('', true) %} + {%- if content is string -%} + {%- if '' not in content and '' not in content -%} + {%- set content = "" ~ content -%} + {%- endif -%} + {%- else -%} + {%- set content = content -%} + {%- endif -%} + {%- endif %} + {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>assistant\n' }} + {%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %} + {%- if content is string and content | trim | length > 0 %} + {%- if include_content %} + {{- (content | trim) ~ '\n' -}} + {%- else %} + {%- set c = (content | string) %} + {%- if '' in c %} + {%- set c = c.split('')[-1] %} + {%- elif '' in c %} + {%- set c = c.split('')[0] %} + {%- endif %} + {%- set c = "" ~ c | trim %} + {%- if c | length > 0 %} + {{- c ~ '\n' -}} + {%- endif %} + {%- endif %} + {%- else %} + {{- "" -}} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n' -}} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' -}} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value ~ '\n\n' -}} + {%- endfor %} + {%- endif %} + {{- '\n\n' -}} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- else %} + {%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %} + {{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ 
'<|im_end|>\n' }} + {%- else %} + {%- set c = (content | default('', true) | string) %} + {%- if '' in c and '' in c %} + {%- set c = "" ~ c.split('')[-1] %} + {%- endif %} + {%- set c = c | trim %} + {%- if c | length > 0 %} + {{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>assistant\n<|im_end|>\n' }} + {%- endif %} + {%- endif %} + {%- endif %} + {%- elif message.role == "user" or message.role == "system" %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- set content = message.content | string %} + {{- content }} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} + +{# Generation prompt using computed thinking toggle #} +{%- if add_generation_prompt %} + {%- if toggle.enable %} + {{- '<|im_start|>assistant\n\n' }} + {%- else %} + {{- '<|im_start|>assistant\n' }} + {%- endif %} +{%- endif %} diff --git a/server/create.go b/server/create.go index 9ddb2bf8b..49ca901e6 100644 --- a/server/create.go +++ b/server/create.go @@ -494,15 +494,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, for _, layer := range baseLayers { if layer.GGML != nil { quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization)) + ft := layer.GGML.KV().FileType() + if quantType == "" && hasSourceFP8Tensors(layer.GGML.KV()) && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" && slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) { + quantType = "Q8_0" + } if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == 
"application/vnd.ollama.image.model" { want, err := ggml.ParseFileType(quantType) if err != nil { return err } - ft := layer.GGML.KV().FileType() - if !slices.Contains([]string{"F16", "F32"}, ft.String()) { - return errors.New("quantization is only supported for F16 and F32 models") + if !slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) { + return errors.New("quantization is only supported for F16, BF16 and F32 models") } else if ft != want { layer, err = quantizeLayer(layer, quantType, fn) if err != nil { @@ -531,6 +534,12 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, } r.Parameters["stop"] = []string{""} } + case "laguna": + config.Renderer = cmp.Or(config.Renderer, "laguna") + config.Parser = cmp.Or(config.Parser, "laguna") + case "nemotron_h", "nemotron_h_moe", "nemotron_h_omni": + config.Renderer = cmp.Or(config.Renderer, "nemotron-3-nano") + config.Parser = cmp.Or(config.Parser, "nemotron-3-nano") } } } @@ -606,6 +615,10 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return nil } +func hasSourceFP8Tensors(kv ggml.KV) bool { + return kv.String("source_quantization") == "hf_fp8" && len(kv.Strings("source_fp8_tensors")) > 0 +} + func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) { ft := layer.GGML.KV().FileType() var doneBytes atomic.Uint64 diff --git a/server/laguna_quantization_test.go b/server/laguna_quantization_test.go new file mode 100644 index 000000000..2f90b9f6c --- /dev/null +++ b/server/laguna_quantization_test.go @@ -0,0 +1,90 @@ +package server + +import ( + "testing" + + fsggml "github.com/ollama/ollama/fs/ggml" +) + +func TestLagunaGGUFQuantization(t *testing.T) { + cases := []struct { + name string + tensor string + originalType fsggml.TensorType + requestedType fsggml.TensorType + fileType fsggml.FileType + blockCount int + wantType fsggml.TensorType + wantQuantize bool + }{ + { + name: 
"non_routed_weights_preserved", + tensor: "blk.1.attn_q.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ8_0, + fileType: fsggml.FileTypeQ8_0, + blockCount: 2, + wantType: fsggml.TensorTypeBF16, + wantQuantize: false, + }, + { + name: "shared_expert_weights_preserved", + tensor: "blk.1.ffn_gate_shexp.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ4_K, + fileType: fsggml.FileTypeQ4_K_M, + blockCount: 2, + wantType: fsggml.TensorTypeBF16, + wantQuantize: false, + }, + { + name: "routed_gate_q8", + tensor: "blk.1.ffn_gate_exps.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ8_0, + fileType: fsggml.FileTypeQ8_0, + blockCount: 2, + wantType: fsggml.TensorTypeQ8_0, + wantQuantize: true, + }, + { + name: "routed_down_q4_promoted", + tensor: "blk.1.ffn_down_exps.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ4_K, + fileType: fsggml.FileTypeQ4_K_M, + blockCount: 2, + wantType: fsggml.TensorTypeQ6_K, + wantQuantize: true, + }, + { + name: "routed_down_q4_not_promoted_when_q8_requested", + tensor: "blk.1.ffn_down_exps.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ8_0, + fileType: fsggml.FileTypeQ4_K_M, + blockCount: 2, + wantType: fsggml.TensorTypeQ8_0, + wantQuantize: true, + }, + { + name: "routed_down_q4_k_s_promoted", + tensor: "blk.0.ffn_down_exps.weight", + originalType: fsggml.TensorTypeBF16, + requestedType: fsggml.TensorTypeQ4_K, + fileType: fsggml.FileTypeQ4_K_S, + blockCount: 8, + wantType: fsggml.TensorTypeQ5_K, + wantQuantize: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + gotType, gotQuantize := lagunaGGUFQuantization(tt.tensor, tt.originalType, tt.requestedType, tt.fileType, tt.blockCount) + if gotType != tt.wantType || gotQuantize != tt.wantQuantize { + t.Fatalf("lagunaGGUFQuantization(%q) = (%s, %v), want (%s, %v)", tt.tensor, gotType, 
gotQuantize, tt.wantType, tt.wantQuantize) + } + }) + } +} diff --git a/server/quantization.go b/server/quantization.go index ee70b5fc0..8d917645c 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -7,6 +7,7 @@ import ( "maps" "os" "slices" + "strconv" "strings" "unsafe" @@ -51,11 +52,14 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) { } type quantizeState struct { - nAttnV int // Number of attn_*v* weight tensors - nFfnDown int // Number of ffn_down tensors - iAttnV int // Running counter of number of attn_v tensors that have been processed - iFfnDown int // Running counter of number of ffn_down tensors that have been processed - hasOutput bool // used to figure out if a model shares tok_embd with the output weight + nAttnV int // Number of attn_*v* weight tensors + nFfnDown int // Number of ffn_down tensors + iAttnV int // Running counter of number of attn_v tensors that have been processed + iFfnDown int // Running counter of number of ffn_down tensors that have been processed + hasOutput bool // used to figure out if a model shares tok_embd with the output weight + preserveSourceFP8ToQ8 bool + preserveSourceQ4 bool + sourceFP8Tensors map[string]struct{} } func useMoreBits(iLayer, nLayers int) bool { @@ -108,6 +112,53 @@ func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) { return 0, false } +func isLagunaGGUFRoutedExpertWeight(name string) bool { + return strings.HasSuffix(name, ".weight") && (strings.Contains(name, "ffn_gate_exps") || + strings.Contains(name, "ffn_up_exps") || + strings.Contains(name, "ffn_down_exps")) +} + +func lagunaGGUFBlockIndex(name string) (int, bool) { + if !strings.HasPrefix(name, "blk.") { + return 0, false + } + + parts := strings.SplitN(strings.TrimPrefix(name, "blk."), ".", 2) + if len(parts) != 2 { + return 0, false + } + + i, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, false + } + + return i, true +} + +func lagunaGGUFQuantization(name string, originalType, 
requestedType fsggml.TensorType, ftype fsggml.FileType, blockCount int) (fsggml.TensorType, bool) { + if !isLagunaGGUFRoutedExpertWeight(name) { + return originalType, false + } + + if strings.HasSuffix(name, ".ffn_down_exps.weight") { + if i, ok := lagunaGGUFBlockIndex(name); ok && blockCount > 0 { + switch ftype { + case fsggml.FileTypeQ4_K_M: + if requestedType != fsggml.TensorTypeQ8_0 && useMoreBits(i, blockCount) { + return fsggml.TensorTypeQ6_K, true + } + case fsggml.FileTypeQ4_K_S: + if requestedType != fsggml.TensorTypeQ8_0 && i < blockCount/8 { + return fsggml.TensorTypeQ5_K, true + } + } + } + } + + return requestedType, true +} + func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType { // Ported from llama_tensor_get_type, removed unsupported quantization types nExperts := max(1, kv.Uint("expert_count", 0)) @@ -120,10 +171,10 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType newType = fsggml.TensorTypeQ6_K } } else if strings.Contains(name, "attn_v.weight") { - if (ftype == fsggml.FileTypeQ4_K_M) && + if newType != fsggml.TensorTypeQ8_0 && (ftype == fsggml.FileTypeQ4_K_M) && useMoreBits(qs.iAttnV, qs.nAttnV) { newType = fsggml.TensorTypeQ6_K - } else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 { + } else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 { newType = fsggml.TensorTypeQ5_K } @@ -158,31 +209,35 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType // expert alphabetically, so dense increments the counter and expert uses counter-1. 
var iLayer int if strings.Contains(name, "_exps") { + if kv.Architecture() == "laguna" { + goto finalize + } iLayer = max(0, qs.iFfnDown-1) } else { iLayer = qs.iFfnDown qs.iFfnDown++ } n_layer := qs.nFfnDown - if ftype == fsggml.FileTypeQ4_K_M { + if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M { if useMoreBits(iLayer, n_layer) { newType = fsggml.TensorTypeQ6_K } - } else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 { + } else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 { newType = fsggml.TensorTypeQ5_K } } else if strings.Contains(name, "attn_output.weight") { - if nExperts == 8 { + if newType != fsggml.TensorTypeQ8_0 && nExperts == 8 { if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M { newType = fsggml.TensorTypeQ5_K } } } else if strings.Contains(name, "attn_qkv.weight") { - if ftype == fsggml.FileTypeQ4_K_M { + if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M { newType = fsggml.TensorTypeQ5_K } } +finalize: if newType.IsQuantized() { nx := shape[0] qk_k := newType.BlockSize() @@ -218,7 +273,12 @@ func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType, kv := maps.Clone(orig.KV()) kv["general.file_type"] = newFileType // kv["general.quantization_version"] = ggml.QuantizationVersion() - qs := &quantizeState{} + qs := &quantizeState{ + sourceFP8Tensors: sourceFP8TensorSet(kv), + } + hasSourceFP8 := hasSourceFP8Tensors(kv) + qs.preserveSourceFP8ToQ8 = hasSourceFP8 && newFileType == fsggml.FileTypeQ8_0 + qs.preserveSourceQ4 = hasSourceFP8 && slices.Contains([]fsggml.FileType{fsggml.FileTypeQ4_K_M, fsggml.FileTypeQ4_K_S}, newFileType) // Build up the quantize state so newType can adjust types layerCount := 0 for k, l := range orig.Tensors().GroupLayers() { @@ -304,13 +364,34 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil newType := fsggml.TensorType(t.Kind) if quantize { + if 
qs.preserveSourceFP8ToQ8 { + if _, ok := qs.sourceFP8Tensors[name]; !ok { + return newType + } + } + if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) { if qt, ok := qwen3LinearAttnQuantType(name); ok { return qt } } + // TODO: Consider extracting architecture-specific GGUF quantization policy + // from server so different quantization backends can share one source of + // truth for model-family specializations. // get more optimal quantization type based on the tensor shape, layer, etc. + if qs.preserveSourceQ4 { + if _, ok := qs.sourceFP8Tensors[name]; !ok { + defaultType = fsggml.TensorTypeQ8_0 + } + } + if kv.Architecture() == "laguna" { + var ok bool + defaultType, ok = lagunaGGUFQuantization(name, newType, defaultType, ftype, int(kv.Uint("block_count", 0))) + if !ok { + return newType + } + } newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype) if newType != defaultType { slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType) @@ -318,3 +399,16 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil } return newType } + +func sourceFP8TensorSet(kv fsggml.KV) map[string]struct{} { + names := kv.Strings("source_fp8_tensors") + if len(names) == 0 { + return nil + } + + out := make(map[string]struct{}, len(names)) + for _, name := range names { + out[name] = struct{}{} + } + return out +} diff --git a/server/quantization_test.go b/server/quantization_test.go index f8f10659c..49665c36f 100644 --- a/server/quantization_test.go +++ b/server/quantization_test.go @@ -308,6 +308,95 @@ func TestQuantizeModel(t *testing.T) { "output.weight": fsggml.TensorTypeQ8_0, }, }, + { + name: "source_fp8_q8_preserves_bf16_tensors", + kv: map[string]any{ + "general.architecture": "test", + "source_quantization": "hf_fp8", + "source_fp8_tensors": 
[]string{"blk.1.ffn_down_exps.weight"}, + }, + tensors: []*fsggml.Tensor{ + { + Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + }, + newType: "Q8_0", + expectedTensorTypes: map[string]fsggml.TensorType{ + "blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ8_0, + "blk.1.attn_q.weight": fsggml.TensorTypeBF16, + }, + }, + { + name: "source_fp8_q4_promotes_bf16_tensors_to_q8", + kv: map[string]any{ + "general.architecture": "test", + "source_quantization": "hf_fp8", + "source_fp8_tensors": []string{ + "blk.1.ffn_gate_exps.weight", + "blk.1.ffn_down_exps.weight", + }, + }, + tensors: []*fsggml.Tensor{ + { + Name: "blk.1.ffn_gate_exps.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.attn_v.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.ffn_down.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.attn_q_norm.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: 
[]uint64{256}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "blk.1.ffn_gate_inp.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + { + Name: "output.weight", Kind: uint32(fsggml.TensorTypeBF16), + Offset: uint64(0), Shape: []uint64{256, 1}, + WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]), + }, + }, + newType: "Q4_K_M", + expectedTensorTypes: map[string]fsggml.TensorType{ + "blk.1.ffn_gate_exps.weight": fsggml.TensorTypeQ4_K, + "blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ6_K, + "blk.1.attn_q.weight": fsggml.TensorTypeQ8_0, + "blk.1.attn_v.weight": fsggml.TensorTypeQ8_0, + "blk.1.ffn_down.weight": fsggml.TensorTypeQ8_0, + "blk.1.attn_q_norm.weight": fsggml.TensorTypeBF16, + "blk.1.ffn_gate_inp.weight": fsggml.TensorTypeBF16, + "output.weight": fsggml.TensorTypeQ8_0, + }, + }, { name: "f32_short_data", kv: map[string]any{ diff --git a/server/routes.go b/server/routes.go index ef3aff174..42caff02c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -618,8 +618,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if builtinParser != nil { - // only send messages with meaningful content (empty messages confuse clients) - if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { + // Emit chunks that carry logprobs even if the parser is still buffering + // visible content, otherwise generate logprobs disappear for models with + // builtin thinking/tool parsers. 
+ if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 || len(res.Logprobs) > 0 { ch <- res } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 3a4dfb6dc..0e1fc5819 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -102,6 +102,36 @@ func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.Resp return w.ResponseRecorder } +func readCreatedModelConfig(t *testing.T, name string) model.ConfigV2 { + t.Helper() + + mf, err := manifest.ParseNamedManifest(model.ParseName(name)) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + if mf.Config.Digest == "" { + t.Fatalf("unexpected empty config digest for manifest") + } + + configPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("config blob path: %v", err) + } + + cfgFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("open config blob: %v", err) + } + defer cfgFile.Close() + + var cfg model.ConfigV2 + if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil { + t.Fatalf("decode config: %v", err) + } + + return cfg +} + func checkFileExists(t *testing.T, p string, expect []string) { t.Helper() @@ -981,6 +1011,134 @@ func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) { } } +func TestCreateLagunaDetectsRendererParser(t *testing.T) { + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "laguna", + "general.parameter_count": uint64(33_400_000_000), + }, nil) + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test", + Files: map[string]string{"test.gguf": digest}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + mf, err := manifest.ParseNamedManifest(model.ParseName("test")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + if mf.Config.Digest == "" { + 
t.Fatalf("unexpected empty config digest for manifest") + } + + configPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("config blob path: %v", err) + } + + cfgFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("open config blob: %v", err) + } + defer cfgFile.Close() + + var cfg model.ConfigV2 + if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil { + t.Fatalf("decode config: %v", err) + } + + if cfg.Renderer != "laguna" { + t.Fatalf("expected renderer %q, got %q", "laguna", cfg.Renderer) + } + if cfg.Parser != "laguna" { + t.Fatalf("expected parser %q, got %q", "laguna", cfg.Parser) + } +} + +func TestCreateNemotronHDefaultsRendererParser(t *testing.T) { + gin.SetMode(gin.TestMode) + + for _, arch := range []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"} { + t.Run(arch, func(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": arch, + }, nil) + + name := strings.ReplaceAll(arch, "_", "-") + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: name, + Files: map[string]string{"test.gguf": digest}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + cfg := readCreatedModelConfig(t, name) + if cfg.Renderer != "nemotron-3-nano" { + t.Fatalf("expected renderer %q, got %q", "nemotron-3-nano", cfg.Renderer) + } + if cfg.Parser != "nemotron-3-nano" { + t.Fatalf("expected parser %q, got %q", "nemotron-3-nano", cfg.Parser) + } + }) + } +} + +func TestCreateNemotronHDefaultsKeepExplicitRendererParser(t *testing.T) { + gin.SetMode(gin.TestMode) + + for _, arch := range []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"} { + t.Run(arch, func(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": arch, + }, nil) + + const ( + renderer = 
"custom-renderer" + parser = "custom-parser" + ) + + name := strings.ReplaceAll(arch, "_", "-") + "-custom" + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: name, + Files: map[string]string{"test.gguf": digest}, + Renderer: renderer, + Parser: parser, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + cfg := readCreatedModelConfig(t, name) + if cfg.Renderer != renderer { + t.Fatalf("expected renderer %q, got %q", renderer, cfg.Renderer) + } + if cfg.Parser != parser { + t.Fatalf("expected parser %q, got %q", parser, cfg.Parser) + } + }) + } +} + func TestDetectModelTypeFromFiles(t *testing.T) { t.Run("gguf file", func(t *testing.T) { _, digest := createBinFile(t, nil, nil) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 80442c238..008b499cd 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -1473,6 +1473,119 @@ func TestGenerateLogprobs(t *testing.T) { }) } +func TestGenerateLogprobsWithBuiltinParser(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{} + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + responses := []llm.CompletionResponse{ + { + Content: "h", + Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "h", Logprob: -0.1}}}, + }, + { + Content: "iHello", + Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "hi", Logprob: -0.2}}}, + }, + { + Content: " world", + Done: true, + DoneReason: llm.DoneReasonStop, + Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.3}}}, + }, + } + + for _, resp := range responses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + fn(resp) + } + } + + return nil + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan 
*runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + waitForRecovery: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + req.successCh <- &runnerRef{llama: &mock} + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) + + if w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: 
"test-generate-logprob-parser", + Files: map[string]string{"file.gguf": digest}, + Parser: "deepseek3", + Template: `{{ .Prompt }}`, + Stream: &stream, + }); w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + noStream := false + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-generate-logprob-parser", + Prompt: "Why is the sky blue?", + Stream: &noStream, + Logprobs: true, + Options: map[string]any{ + "temperature": 0, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp api.GenerateResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if got := len(resp.Logprobs); got != 3 { + t.Fatalf("expected 3 logprob entries, got %d", got) + } +} + func TestChatLogprobs(t *testing.T) { t.Run("invalid top_logprobs negative", func(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/sched.go b/server/sched.go index f040e34f3..d784c2751 100644 --- a/server/sched.go +++ b/server/sched.go @@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.De // Some architectures are not safe with num_parallel > 1. 
// ref: https://github.com/ollama/ollama/issues/4165 - if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 { + if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}, req.model.Config.ModelFamily) && numParallel != 1 { numParallel = 1 slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily) } diff --git a/x/create/client/create.go b/x/create/client/create.go index c5962fdd7..f386813e2 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -18,6 +18,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/manifest" + modelparsers "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/types/model" @@ -132,11 +133,11 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { if isSafetensors { modelType = "safetensors model" spinnerKey = "create" - capabilities = inferSafetensorsCapabilities(opts.ModelDir) // Set parser and renderer name based on architecture parserName = getParserName(opts.ModelDir) rendererName = getRendererName(opts.ModelDir) + capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName)) } else { modelType = "image generation model" spinnerKey = "imagegen" @@ -183,7 +184,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { return nil } -func inferSafetensorsCapabilities(modelDir string) []string { +func inferSafetensorsCapabilities(modelDir, parserName string) []string { capabilities := []string{"completion"} // Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures. 
@@ -195,7 +196,16 @@ func inferSafetensorsCapabilities(modelDir string) []string { capabilities = append(capabilities, "audio") } - if supportsThinking(modelDir) { + var builtinParser modelparsers.Parser + if parserName != "" { + builtinParser = modelparsers.ParserForName(parserName) + } + + if builtinParser != nil && builtinParser.HasToolSupport() { + capabilities = append(capabilities, "tools") + } + + if supportsThinking(modelDir) || (builtinParser != nil && builtinParser.HasThinkingSupport()) { capabilities = append(capabilities, "thinking") } @@ -453,8 +463,8 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) { return layers, nil } -// supportsThinking checks if the model supports thinking mode based on its architecture. -// This reads the config.json from the model directory and checks the architectures field. +// supportsThinking checks if the model supports thinking mode based on known +// architectures that do not expose a cleaner signal in their local metadata. 
func supportsThinking(modelDir string) bool { configPath := filepath.Join(modelDir, "config.json") data, err := os.ReadFile(configPath) @@ -554,6 +564,9 @@ func getParserName(modelDir string) string { // Check architectures for known parsers for _, arch := range cfg.Architectures { archLower := strings.ToLower(arch) + if strings.Contains(archLower, "laguna") { + return "laguna" + } if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { return "glm-4.7" } @@ -571,6 +584,9 @@ func getParserName(modelDir string) string { // Also check model_type if cfg.ModelType != "" { typeLower := strings.ToLower(cfg.ModelType) + if strings.Contains(typeLower, "laguna") { + return "laguna" + } if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { return "glm-4.7" } @@ -608,6 +624,9 @@ func getRendererName(modelDir string) string { // Check architectures for known renderers for _, arch := range cfg.Architectures { archLower := strings.ToLower(arch) + if strings.Contains(archLower, "laguna") { + return "laguna" + } if strings.Contains(archLower, "gemma4") { return "gemma4" } @@ -625,6 +644,9 @@ func getRendererName(modelDir string) string { // Also check model_type if cfg.ModelType != "" { typeLower := strings.ToLower(cfg.ModelType) + if strings.Contains(typeLower, "laguna") { + return "laguna" + } if strings.Contains(typeLower, "gemma4") { return "gemma4" } diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index a8a6fc4a8..ead7e4233 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -352,7 +352,7 @@ func TestInferSafetensorsCapabilities(t *testing.T) { t.Fatal(err) } - if got := inferSafetensorsCapabilities(dir); !slices.Equal(got, tt.want) { + if got := inferSafetensorsCapabilities(dir, ""); !slices.Equal(got, tt.want) { t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want) } }) @@ -554,6 +554,11 @@ func TestSupportsThinking(t *testing.T) { configJSON: 
`{"model_type": "deepseek"}`, want: true, }, + { + name: "laguna architecture without template", + configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`, + want: false, + }, { name: "empty config", configJSON: `{}`, @@ -584,6 +589,55 @@ func TestSupportsThinking_NoConfig(t *testing.T) { } } +func TestInferSafetensorsCapabilitiesFromParser(t *testing.T) { + tests := []struct { + name string + parserName string + want []string + }{ + { + name: "laguna tools and thinking", + parserName: "laguna", + want: []string{"completion", "tools", "thinking"}, + }, + { + name: "functiongemma tools only", + parserName: "functiongemma", + want: []string{"completion", "tools"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + if got := inferSafetensorsCapabilities(dir, tt.parserName); !slices.Equal(got, tt.want) { + t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want) + } + }) + } +} + +func TestInferSafetensorsCapabilitiesLaguna(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`), 0o644); err != nil { + t.Fatal(err) + } + + got := inferSafetensorsCapabilities(dir, "laguna") + for _, want := range []string{"completion", "tools", "thinking"} { + if !slices.Contains(got, want) { + t.Fatalf("capabilities %v missing %q", got, want) + } + } + if slices.Contains(got, "vision") || slices.Contains(got, "audio") { + t.Fatalf("unexpected non-text capability in %v", got) + } +} + func TestGetParserName(t *testing.T) { tests := []struct { name string @@ -615,6 +669,11 @@ func TestGetParserName(t *testing.T) { configJSON: `{"model_type": "qwen3"}`, want: "qwen3", }, + { + name: "laguna model", + configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": 
"laguna"}`, + want: "laguna", + }, { name: "no config", configJSON: `{}`, @@ -660,6 +719,11 @@ func TestGetRendererName(t *testing.T) { configJSON: `{"architectures": ["LlamaForCausalLM"]}`, want: "", }, + { + name: "laguna model", + configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`, + want: "laguna", + }, } for _, tt := range tests { diff --git a/x/create/create.go b/x/create/create.go index 1388b6067..79b174487 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -776,6 +776,11 @@ type tensorImportTransform interface { quantizationType(name string, shape []int32, quantize string) string } +type sourceFP8TensorImportTransform interface { + sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string + sourceFP8BF16Quantization(name string, shape []int32, requested string) string +} + type noopImportTransform struct{} func (noopImportTransform) skipTensor(string) bool { return false } @@ -804,6 +809,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{ "Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform, "Gemma4ForCausalLM": newGemma4ImportTransform, "Gemma4ForConditionalGeneration": newGemma4ImportTransform, + "LagunaForCausalLM": newLagunaImportTransform, } func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) { @@ -842,6 +848,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La if err != nil { return fmt.Errorf("failed to construct import transform for architecture %q: %w", sourceConfig.Architecture(), err) } + sourceFP8Transform, _ := importTransform.(sourceFP8TensorImportTransform) // Resolve the optional packed layer creator var packedCreator PackedTensorLayerCreator @@ -991,9 +998,17 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La // synthetic tests may not pass the generic import size filter. 
quantizeType = "mxfp8" } - quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType) + if sourceFP8Transform != nil { + quantizeType = sourceFP8Transform.sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType) + } else { + quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType) + } case sourceQuantKind == sourceQuantizedKindSourceFP8: - quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize) + if sourceFP8Transform != nil { + quantizeType = sourceFP8Transform.sourceFP8BF16Quantization(outTD.Name, outTD.Shape, quantize) + } else { + quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize) + } case effectiveQuantize != "": quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize) } diff --git a/x/create/laguna.go b/x/create/laguna.go new file mode 100644 index 000000000..8355c9dba --- /dev/null +++ b/x/create/laguna.go @@ -0,0 +1,59 @@ +package create + +import ( + "strings" + + "github.com/ollama/ollama/x/safetensors" +) + +type lagunaImportTransform struct{} + +func newLagunaImportTransform(string, sourceModelConfig) (tensorImportTransform, error) { + return lagunaImportTransform{}, nil +} + +func (lagunaImportTransform) skipTensor(string) bool { return false } + +func (lagunaImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { + if td == nil { + return nil, nil + } + return []*safetensors.TensorData{td}, nil +} + +func (lagunaImportTransform) quantizationType(name string, shape []int32, quantize string) string { + if !lagunaIsHFRoutedExpertWeight(name) { + return "" + } + return GetTensorQuantization(name, shape, quantize) +} + +func (lagunaImportTransform) sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string { + if !lagunaIsHFRoutedExpertWeight(name) { + return "" + } + + switch 
normalizeQuantType(requested) { + case "nvfp4", "mxfp4": + if lagunaKeepSourceFP8TensorAtMXFP8(name, shape) { + return "mxfp8" + } + } + return fallback +} + +func (lagunaImportTransform) sourceFP8BF16Quantization(string, []int32, string) string { + return "" +} + +func lagunaKeepSourceFP8TensorAtMXFP8(name string, shape []int32) bool { + if len(shape) != 2 || !isAligned(shape, "mxfp8") { + return false + } + + return strings.Contains(name, "down_proj") +} + +func lagunaIsHFRoutedExpertWeight(name string) bool { + return strings.HasSuffix(name, ".weight") && strings.Contains(name, ".mlp.experts.") +} diff --git a/x/create/laguna_test.go b/x/create/laguna_test.go new file mode 100644 index 000000000..7e122366e --- /dev/null +++ b/x/create/laguna_test.go @@ -0,0 +1,200 @@ +package create + +import ( + "io" + "os" + "path/filepath" + "testing" + + st "github.com/ollama/ollama/x/safetensors" +) + +func TestCreateSafetensorsModel_LagunaHFFP8RespectsSourceTensorPrecision(t *testing.T) { + tests := []struct { + name string + requested string + wantFP8Gate string + wantFP8Up string + wantFP8Down string + wantBF16QProj string + }{ + { + name: "default mxfp8 import keeps source bf16 tensors", + requested: "", + wantFP8Gate: "mxfp8", + wantFP8Up: "mxfp8", + wantFP8Down: "mxfp8", + wantBF16QProj: "", + }, + { + name: "nvfp4 import keeps source bf16 tensors and preserves down_proj at mxfp8", + requested: "nvfp4", + wantFP8Gate: "nvfp4", + wantFP8Up: "nvfp4", + wantFP8Down: "mxfp8", + wantBF16QProj: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + configJSON := `{ + "model_type": "laguna", + "architectures": ["LagunaForCausalLM"], + "quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]} + }` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + createTestSafetensors(t, filepath.Join(dir, 
"model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)), + st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + }) + + quantizeByName := make(map[string]string) + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + if _, err := io.ReadAll(r); err != nil { + return LayerInfo{}, err + } + return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + if _, err := io.ReadAll(r); err != nil { + return nil, err + } + quantizeByName[name] = quantize + return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil } + 
+ if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + if got := quantizeByName["model.layers.0.mlp.experts.0.gate_proj.weight"]; got != tt.wantFP8Gate { + t.Fatalf("gate_proj quantization = %q, want %q", got, tt.wantFP8Gate) + } + if got := quantizeByName["model.layers.0.mlp.experts.0.up_proj.weight"]; got != tt.wantFP8Up { + t.Fatalf("up_proj quantization = %q, want %q", got, tt.wantFP8Up) + } + if got := quantizeByName["model.layers.0.mlp.experts.0.down_proj.weight"]; got != tt.wantFP8Down { + t.Fatalf("down_proj quantization = %q, want %q", got, tt.wantFP8Down) + } + for _, name := range []string{ + "model.layers.0.self_attn.q_proj.weight", + "model.embed_tokens.weight", + "lm_head.weight", + "model.layers.0.mlp.gate.weight", + } { + if got := quantizeByName[name]; got != tt.wantBF16QProj { + t.Fatalf("%s quantization = %q, want %q", name, got, tt.wantBF16QProj) + } + } + }) + } +} + +func TestCreateSafetensorsModel_LagunaBF16QuantizesOnlyRoutedExperts(t *testing.T) { + tests := []struct { + name string + requested string + want map[string]string + }{ + { + name: "int8 quantizes only routed experts", + requested: "int8", + want: map[string]string{ + "model.layers.0.mlp.experts.0.gate_proj.weight": "int8", + "model.layers.0.mlp.experts.0.up_proj.weight": "int8", + "model.layers.0.mlp.experts.0.down_proj.weight": "int8", + "model.layers.0.mlp.shared_experts.gate_proj.weight": "", + "model.layers.0.mlp.shared_experts.down_proj.weight": "", + "model.layers.0.self_attn.q_proj.weight": "", + "model.layers.0.mlp.down_proj.weight": "", + "model.embed_tokens.weight": "", + "lm_head.weight": "", + "model.layers.0.mlp.gate.weight": "", + }, + }, + { + name: "int4 keeps routed down_proj at int8 and leaves others bf16", + requested: "int4", + want: map[string]string{ + "model.layers.0.mlp.experts.0.gate_proj.weight": "int4", 
+ "model.layers.0.mlp.experts.0.up_proj.weight": "int4", + "model.layers.0.mlp.experts.0.down_proj.weight": "int8", + "model.layers.0.mlp.shared_experts.gate_proj.weight": "", + "model.layers.0.mlp.shared_experts.down_proj.weight": "", + "model.layers.0.self_attn.q_proj.weight": "", + "model.layers.0.mlp.down_proj.weight": "", + "model.embed_tokens.weight": "", + "lm_head.weight": "", + "model.layers.0.mlp.gate.weight": "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + configJSON := `{ + "model_type": "laguna", + "architectures": ["LagunaForCausalLM"] + }` + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.layers.0.mlp.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + 
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)), + }) + + quantizeByName := make(map[string]string) + + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + if _, err := io.ReadAll(r); err != nil { + return LayerInfo{}, err + } + return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + if _, err := io.ReadAll(r); err != nil { + return nil, err + } + quantizeByName[name] = quantize + return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil + } + writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil } + + if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil { + t.Fatalf("CreateSafetensorsModel failed: %v", err) + } + + for name, want := range tt.want { + if got := quantizeByName[name]; got != want { + t.Fatalf("%s quantization = %q, want %q", name, got, want) + } + } + }) + } +} diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index ea16e4ea2..316691d94 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -4,6 +4,7 @@ import ( _ "github.com/ollama/ollama/x/models/gemma3" _ "github.com/ollama/ollama/x/models/gemma4" _ "github.com/ollama/ollama/x/models/glm4_moe_lite" + _ "github.com/ollama/ollama/x/models/laguna" _ "github.com/ollama/ollama/x/models/llama" _ "github.com/ollama/ollama/x/models/qwen3" _ "github.com/ollama/ollama/x/models/qwen3_5" diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 3a67da279..4fa21e270 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -34,6 +34,19 @@ var SiLU = Compile1( Shapeless(), ) +// SoftplusF32 returns softplus(x) computed in float32 precision 
and cast back +// to x's original dtype, as a fused kernel. Matches the laguna attention +// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype). +var SoftplusF32 = Compile1( + "SoftplusF32", + func(x *Array) *Array { + dt := x.DType() + zero := FromValue[float32](0) + return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt) + }, + Shapeless(), +) + // SwiGLU returns silu(gate) * up as a fused kernel. var SwiGLU = Compile2( "SwiGLU", diff --git a/x/models/laguna/laguna.go b/x/models/laguna/laguna.go new file mode 100644 index 000000000..fea4e3f26 --- /dev/null +++ b/x/models/laguna/laguna.go @@ -0,0 +1,1216 @@ +// Package laguna provides the Poolside Laguna text model implementation for MLX. +package laguna + +import ( + "encoding/json" + "fmt" + "math" + "strings" + + "github.com/ollama/ollama/x/mlxrunner/batch" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" +) + +func init() { + base.Register("LagunaForCausalLM", NewModel) +} + +var _ base.Model = (*Model)(nil) + +type RopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + RopeType string `json:"rope_type"` + Type string `json:"type"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + Factor float32 `json:"factor"` + OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + AttentionFactor float32 `json:"attention_factor"` +} + +type gatingMode string + +type ropeConfig struct { + flat *RopeParameters + full *RopeParameters + sliding *RopeParameters + nested bool +} + +type Config struct { + ModelType string `json:"model_type"` + HiddenSize int32 `json:"hidden_size"` + IntermediateSize int32 `json:"intermediate_size"` + MoeIntermediateSize int32 
`json:"moe_intermediate_size"` + SharedExpertIntermediate int32 `json:"shared_expert_intermediate_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumAttentionHeadsPerLayer []int32 `json:"num_attention_heads_per_layer"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + RMSNormEps float32 `json:"rms_norm_eps"` + VocabSize int32 `json:"vocab_size"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + LayerTypes []string `json:"layer_types"` + SlidingWindow int32 `json:"sliding_window"` + MLPOnlyLayers []int32 `json:"mlp_only_layers"` + DecoderSparseStep int32 `json:"decoder_sparse_step"` + NumExperts int32 `json:"num_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + NormTopKProb bool `json:"norm_topk_prob"` + MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"` + MoeApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"` + Gating string `json:"gating"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + RopeParameters *RopeParameters `json:"rope_parameters"` + RopeScaling *RopeParameters `json:"rope_scaling"` + SWARopeParameters *RopeParameters `json:"swa_rope_parameters"` + + QuantGroupSize int `json:"-"` + QuantBits int `json:"-"` + QuantMode string `json:"-"` + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` + + Scale float32 `json:"-"` + FullRopeDim int `json:"-"` + FullRopeBase float32 `json:"-"` + FullRopeScale float32 `json:"-"` + FullRopeFreqs *mlx.Array `json:"-"` + SlidingRopeDim int `json:"-"` + SlidingRopeBase float32 `json:"-"` + SlidingRopeScale float32 `json:"-"` +} + +type Model struct { + EmbedTokens nn.EmbeddingLayer + Layers []*Layer + Norm *nn.RMSNorm + LMHead nn.LinearLayer + + tok *tokenizer.Tokenizer + *Config +} + +type Layer struct { + InputNorm *nn.RMSNorm + 
PostAttentionNorm *nn.RMSNorm + Attention *Attention + MLP MLPBlock + + LayerIdx int32 + IsSliding bool +} + +type Attention struct { + QProj nn.LinearLayer + KProj nn.LinearLayer + VProj nn.LinearLayer + OProj nn.LinearLayer + GProj nn.LinearLayer + + QNorm *nn.RMSNorm + KNorm *nn.RMSNorm + + NumHeads int32 +} + +type MLPBlock interface { + Forward(x *mlx.Array, cfg *Config) *mlx.Array +} + +type DenseMLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +type SparseMoE struct { + Gate nn.LinearLayer + SwitchMLP *SwitchMLP + SharedExpert *DenseMLP + EScoreCorrectionBias *mlx.Array +} + +type SwitchMLP struct { + GateUpWeight *mlx.Array + GateWeight *mlx.Array + UpWeight *mlx.Array + DownWeight *mlx.Array + + GateUpWeightQ, GateUpScales, GateUpBiases *mlx.Array + GateWeightQ, GateScales, GateBiases *mlx.Array + UpWeightQ, UpScales, UpBiases *mlx.Array + DownWeightQ, DownScales, DownBiases *mlx.Array + GateGlobalScale, UpGlobalScale *mlx.Array + DownGlobalScale *mlx.Array + + GateUpBits, GateBits, UpBits, DownBits int + GateUpGroupSize, GateGroupSize, UpGroupSize, DownGroupSize int + GateUpMode, GateMode, UpMode, DownMode string + UseQuantized, UseFusedGateUp bool +} + +type stackedExpertWeights struct { + Weight *mlx.Array + Scales *mlx.Array + Biases *mlx.Array + GlobalScales *mlx.Array + Bits int + GroupSize int + Mode string +} + +func parseConfig(configData []byte) (Config, error) { + type rawConfig struct { + ModelType string `json:"model_type"` + HiddenSize int32 `json:"hidden_size"` + IntermediateSize int32 `json:"intermediate_size"` + MoeIntermediateSize int32 `json:"moe_intermediate_size"` + SharedExpertIntermediate int32 `json:"shared_expert_intermediate_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumAttentionHeadsPerLayer []int32 `json:"num_attention_heads_per_layer"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 
`json:"head_dim"` + RMSNormEps float32 `json:"rms_norm_eps"` + VocabSize int32 `json:"vocab_size"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + LayerTypes []string `json:"layer_types"` + SlidingWindow int32 `json:"sliding_window"` + MLPOnlyLayers []int32 `json:"mlp_only_layers"` + MLPLayerTypes []string `json:"mlp_layer_types"` + DecoderSparseStep int32 `json:"decoder_sparse_step"` + NumExperts int32 `json:"num_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + NormTopKProb *bool `json:"norm_topk_prob"` + MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"` + MoeApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"` + Gating gatingMode `json:"gating"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + RopeParameters ropeConfig `json:"rope_parameters"` + RopeScaling *RopeParameters `json:"rope_scaling"` + SWARopeParameters *RopeParameters `json:"swa_rope_parameters"` + } + + var raw rawConfig + if err := json.Unmarshal(configData, &raw); err != nil { + return Config{}, fmt.Errorf("parse config: %w", err) + } + + mlpOnlyLayers, err := denseLayers(raw.MLPOnlyLayers, raw.MLPLayerTypes) + if err != nil { + return Config{}, err + } + + fullRope := raw.RopeParameters.fullParams() + if fullRope == nil { + fullRope = raw.RopeScaling + } + swaRope := raw.SWARopeParameters + if nestedSwa := raw.RopeParameters.slidingParams(); nestedSwa != nil { + swaRope = nestedSwa + } + + cfg := Config{ + ModelType: raw.ModelType, + HiddenSize: raw.HiddenSize, + IntermediateSize: raw.IntermediateSize, + MoeIntermediateSize: raw.MoeIntermediateSize, + SharedExpertIntermediate: raw.SharedExpertIntermediate, + NumHiddenLayers: raw.NumHiddenLayers, + NumAttentionHeads: raw.NumAttentionHeads, + NumAttentionHeadsPerLayer: raw.NumAttentionHeadsPerLayer, + NumKeyValueHeads: raw.NumKeyValueHeads, + HeadDim: raw.HeadDim, + 
RMSNormEps: raw.RMSNormEps, + VocabSize: raw.VocabSize, + MaxPositionEmbeddings: raw.MaxPositionEmbeddings, + LayerTypes: raw.LayerTypes, + SlidingWindow: raw.SlidingWindow, + MLPOnlyLayers: mlpOnlyLayers, + DecoderSparseStep: raw.DecoderSparseStep, + NumExperts: raw.NumExperts, + NumExpertsPerTok: raw.NumExpertsPerTok, + NormTopKProb: defaultBool(raw.NormTopKProb, true), + MoeRoutedScalingFactor: raw.MoeRoutedScalingFactor, + MoeApplyRouterWeightOnInput: raw.MoeApplyRouterWeightOnInput, + Gating: raw.Gating.normalized(), + TieWordEmbeddings: raw.TieWordEmbeddings, + RopeTheta: raw.RopeTheta, + PartialRotaryFactor: raw.PartialRotaryFactor, + RopeParameters: fullRope, + RopeScaling: raw.RopeScaling, + SWARopeParameters: swaRope, + } + + if cfg.HiddenSize <= 0 { + return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize) + } + if cfg.NumHiddenLayers <= 0 { + return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers) + } + if cfg.NumAttentionHeads <= 0 && len(cfg.NumAttentionHeadsPerLayer) == 0 { + return Config{}, fmt.Errorf("missing num_attention_heads") + } + if cfg.NumKeyValueHeads <= 0 { + cfg.NumKeyValueHeads = cfg.NumAttentionHeads + } + if cfg.HeadDim <= 0 { + if cfg.NumAttentionHeads <= 0 || cfg.HiddenSize%cfg.NumAttentionHeads != 0 { + return Config{}, fmt.Errorf("cannot infer head_dim") + } + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.IntermediateSize <= 0 { + return Config{}, fmt.Errorf("invalid intermediate_size: %d", cfg.IntermediateSize) + } + if cfg.MoeIntermediateSize <= 0 { + cfg.MoeIntermediateSize = cfg.IntermediateSize + } + if cfg.SharedExpertIntermediate <= 0 { + cfg.SharedExpertIntermediate = cfg.MoeIntermediateSize + } + if cfg.DecoderSparseStep <= 0 { + cfg.DecoderSparseStep = 1 + } + if cfg.NumExpertsPerTok <= 0 && cfg.NumExperts > 0 { + cfg.NumExpertsPerTok = 1 + } + if cfg.MoeRoutedScalingFactor == 0 { + 
cfg.MoeRoutedScalingFactor = 1 + } + + ropeParams := cfg.RopeParameters + if ropeParams == nil { + ropeParams = cfg.RopeScaling + } + cfg.FullRopeBase = cfg.RopeTheta + if cfg.FullRopeBase == 0 && ropeParams != nil && ropeParams.RopeTheta > 0 { + cfg.FullRopeBase = ropeParams.RopeTheta + } + if cfg.FullRopeBase == 0 { + cfg.FullRopeBase = 10000 + } + fullPartial := cfg.PartialRotaryFactor + if fullPartial == 0 && ropeParams != nil && ropeParams.PartialRotaryFactor > 0 { + fullPartial = ropeParams.PartialRotaryFactor + } + if fullPartial == 0 { + fullPartial = 1 + } + cfg.FullRopeDim = clampRopeDim(int(float32(cfg.HeadDim)*fullPartial), int(cfg.HeadDim)) + cfg.FullRopeScale = 1 + if ropeParams != nil && strings.EqualFold(ropeParams.ropeType(), "yarn") { + cfg.FullRopeFreqs, cfg.FullRopeScale = buildYarnRopeFreqs(cfg.FullRopeDim, cfg.FullRopeBase, ropeParams) + } + + cfg.SlidingRopeBase = cfg.FullRopeBase + slidingPartial := fullPartial + if cfg.SWARopeParameters != nil { + if cfg.SWARopeParameters.RopeTheta > 0 { + cfg.SlidingRopeBase = cfg.SWARopeParameters.RopeTheta + } + if cfg.SWARopeParameters.PartialRotaryFactor > 0 { + slidingPartial = cfg.SWARopeParameters.PartialRotaryFactor + } + } + cfg.SlidingRopeDim = clampRopeDim(int(float32(cfg.HeadDim)*slidingPartial), int(cfg.HeadDim)) + cfg.SlidingRopeScale = 1 + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + return cfg, nil +} + +func (g *gatingMode) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err == nil { + *g = gatingMode(s) + return nil + } + + var enabled bool + if err := json.Unmarshal(b, &enabled); err == nil { + if enabled { + *g = "per-head" + } else { + *g = "false" + } + return nil + } + + if string(b) == "null" { + return nil + } + return fmt.Errorf("unsupported Laguna gating JSON value %s", string(b)) +} + +func (g gatingMode) normalized() string { + if strings.EqualFold(string(g), "true") { + return "per-head" + } + return string(g) +} + +func (r 
// denseLayers resolves which layer indices use a dense MLP.
//
// A non-empty mlpOnlyLayers list wins and is returned as-is. Otherwise the
// indices are derived from mlpLayerTypes, where each entry must be "dense" or
// "sparse" (case-insensitive); any other entry is an error. When both inputs
// are empty the result is nil.
func denseLayers(mlpOnlyLayers []int32, mlpLayerTypes []string) ([]int32, error) {
	switch {
	case len(mlpOnlyLayers) > 0:
		return mlpOnlyLayers, nil
	case len(mlpLayerTypes) == 0:
		return nil, nil
	}

	indices := make([]int32, 0, len(mlpLayerTypes))
	for idx := range mlpLayerTypes {
		kind := mlpLayerTypes[idx]
		if strings.EqualFold(kind, "sparse") {
			// Sparse layers are valid but contribute nothing to the dense list.
			continue
		}
		if !strings.EqualFold(kind, "dense") {
			return nil, fmt.Errorf("unsupported mlp_layer_types[%d]=%q", idx, kind)
		}
		indices = append(indices, int32(idx))
	}
	return indices, nil
}
// yarnCorrectionRange returns the (low, high) bounds used by yarnRamp to blend
// interpolated and extrapolated RoPE frequencies.
//
// For each of betaFast and betaSlow it evaluates
//
//	dim * ln(maxPosition / (beta * 2π)) / (2 * ln(base))
//
// then floors the betaFast result and ceils the betaSlow result. low is
// clamped to be at least 0 and high to be at most dim-1. If the two bounds
// coincide, high is nudged up by 0.001 so the ramp never divides by zero.
func yarnCorrectionRange(betaFast, betaSlow float32, dim int, base float32, maxPosition int32) (float64, float64) {
	width := float64(dim)
	span := float64(maxPosition)
	denom := 2 * math.Log(float64(base))

	fast := width * math.Log(span/(float64(betaFast)*2*math.Pi)) / denom
	slow := width * math.Log(span/(float64(betaSlow)*2*math.Pi)) / denom

	lo := math.Max(math.Floor(fast), 0)
	hi := math.Min(math.Ceil(slow), float64(dim-1))
	if lo == hi {
		hi += 0.001
	}
	return lo, hi
}
maxDim { + return maxDim + } + if v%2 != 0 { + v-- + } + if v <= 0 { + return maxDim + } + return v +} + +func NewModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + cfg, err := parseConfig(configData) + if err != nil { + return nil, err + } + if qt := root.QuantType(); qt != "" { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs + } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") + } + cfg.TensorQuant = root.AllTensorQuant() + + tokData, err := root.Manifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData} + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*Layer, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + for i := range cfg.NumHiddenLayers { + m.Layers[i] = &Layer{LayerIdx: i, IsSliding: layerIsSliding(&cfg, i)} + } + return m, nil +} + +func layerIsSliding(cfg *Config, layer int32) bool { + if len(cfg.LayerTypes) == int(cfg.NumHiddenLayers) { + return strings.EqualFold(cfg.LayerTypes[layer], "sliding_attention") + } + return false +} + +func layerUsesMoE(cfg *Config, layer int32) bool { + if cfg.NumExperts <= 0 { + return false + } + for _, l := range cfg.MLPOnlyLayers { + if l == layer { + return false + } + } + return (layer+1)%cfg.DecoderSparseStep == 
// supportsGatherQMM reports whether the quantization mode/bit-width pair can
// stay quantized for the gathered expert matmul path.
//
// Accepted pairs: affine at 4 or 8 bits, mxfp8 at 8 bits, and nvfp4 or mxfp4
// at 4 bits. Everything else is rejected.
func supportsGatherQMM(mode string, bits int) bool {
	switch bits {
	case 4:
		return mode == "affine" || mode == "nvfp4" || mode == "mxfp4"
	case 8:
		return mode == "affine" || mode == "mxfp8"
	}
	return false
}
gateW.Weight.NumDims() == 3 && upW.Weight.NumDims() == 3 +} + +func fuseExpertStacks(a, b *mlx.Array, axis int) *mlx.Array { + if a == nil || !a.Valid() || b == nil || !b.Valid() { + return nil + } + out := mlx.Concatenate([]*mlx.Array{a, b}, axis).Clone() + mlx.Eval(out) + return out +} + +func combinedTensorGlobalScale(tensors map[string]*mlx.Array, key string) (*mlx.Array, []string) { + var names []string + weightGlobal := tensors[key+".global_scale"] + if weightGlobal == nil { + weightGlobal = tensors[key+".weight.global_scale"] + } + if weightGlobal != nil { + names = append(names, key+".global_scale", key+".weight.global_scale") + } + if tensors[key+".input_global_scale"] != nil || tensors[key+".weight.input_global_scale"] != nil { + names = append(names, key+".input_global_scale", key+".weight.input_global_scale") + } + switch { + case weightGlobal != nil: + return weightGlobal, names + default: + return nil, nil + } +} + +func collectPerExpertProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, layerPrefix, proj string, numExperts int32) *stackedExpertWeights { + weights := make([]*mlx.Array, 0, numExperts) + scales := make([]*mlx.Array, 0, numExperts) + biases := make([]*mlx.Array, 0, numExperts) + globalScales := make([]*mlx.Array, 0, numExperts) + consumedKeys := make([]string, 0, numExperts*5) + bits := 0 + groupSize := 0 + mode := cfg.QuantMode + + for e := range numExperts { + base := fmt.Sprintf("%s.mlp.experts.%d.%s", layerPrefix, e, proj) + w, key := tensorAny(tensors, base+".weight", base) + if w == nil { + return nil + } + consumedKeys = append(consumedKeys, key) + s := tensors[key+"_scale"] + if s == nil { + s = tensors[key+".scale"] + } + if s == nil { + weights = append(weights, w) + continue + } + consumedKeys = append(consumedKeys, key+"_scale", key+".scale") + qb := tensors[key+"_qbias"] + if qb == nil { + qb = tensors[key+".bias"] + } + if qb != nil { + consumedKeys = append(consumedKeys, key+"_qbias", key+".bias") + } 
+ globalScale, globalScaleKeys := combinedTensorGlobalScale(tensors, key) + if globalScale != nil { + consumedKeys = append(consumedKeys, globalScaleKeys...) + } + gs, b, m := model.ResolveLinearQuantParams(cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant, key, w, s) + if bits == 0 { + bits = b + groupSize = gs + mode = m + } + if useQuantized && supportsGatherQMM(m, b) { + weights = append(weights, w) + scales = append(scales, s) + if globalScale != nil { + globalScales = append(globalScales, globalScale) + } + if qb != nil { + biases = append(biases, qb) + } + } else { + deq := mlx.Dequantize(w, s, qb, gs, b, m) + if globalScale != nil { + deq = mlx.Mul(deq, globalScale) + globalScales = append(globalScales, globalScale) + } + weights = append(weights, deq) + } + } + + out := &stackedExpertWeights{Weight: stackAndClone(weights), Bits: bits, GroupSize: groupSize, Mode: mode} + if len(scales) == len(weights) { + out.Scales = stackAndClone(scales) + } + if len(biases) == len(weights) { + out.Biases = stackAndClone(biases) + } + if len(globalScales) == len(weights) { + out.GlobalScales = stackAndClone(globalScales) + } + freeTensorKeys(tensors, consumedKeys...) 
+ return out +} + +func loadStackedProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, bases ...string) *stackedExpertWeights { + for _, base := range bases { + w, key := tensorAny(tensors, base+".weight", base) + if w == nil { + continue + } + s := tensors[key+"_scale"] + if s == nil { + s = tensors[key+".scale"] + } + if s == nil { + return &stackedExpertWeights{Weight: w} + } + qb := tensors[key+"_qbias"] + if qb == nil { + qb = tensors[key+".bias"] + } + globalScale, _ := combinedTensorGlobalScale(tensors, key) + gs, b, m := model.ResolveLinearQuantParams(cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant, key, w, s) + if useQuantized && supportsGatherQMM(m, b) { + return &stackedExpertWeights{Weight: w, Scales: s, Biases: qb, GlobalScales: globalScale, Bits: b, GroupSize: gs, Mode: m} + } + deq := mlx.Dequantize(w, s, qb, gs, b, m) + if globalScale != nil { + deq = mlx.Mul(deq, globalScale) + } + return &stackedExpertWeights{Weight: deq, GlobalScales: globalScale, Bits: b, GroupSize: gs, Mode: m} + } + return nil +} + +func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + prefix := resolveWeightPrefix(tensors) + cfg := m.Config + linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) + + m.EmbedTokens = model.MakeEmbeddingLayer(tensors, prefix+"embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) + if m.EmbedTokens == nil { + return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", prefix) + } + if w := tensors[prefix+"norm.weight"]; w != nil { + m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } else { + return fmt.Errorf("missing final norm weight: %snorm.weight", prefix) + } + if cfg.TieWordEmbeddings { + m.LMHead = m.EmbedTokens.AsLinear() + } else if lmHead := linears.Make("lm_head"); lmHead != nil { + m.LMHead = lmHead + } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { + m.LMHead = lmHead + } 
else { + return fmt.Errorf("missing lm_head.weight") + } + + useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + if !useQuantizedExperts && cfg.TensorQuant != nil { + for _, tq := range cfg.TensorQuant { + if tq == nil { + continue + } + _, bits, mode := model.QuantizationParams(tq.QuantType) + if supportsGatherQMM(mode, bits) { + useQuantizedExperts = true + break + } + } + } + + for i := range cfg.NumHiddenLayers { + layerPrefix := fmt.Sprintf("%slayers.%d", prefix, i) + layer := &Layer{ + LayerIdx: i, + IsSliding: layerIsSliding(cfg, i), + Attention: &Attention{NumHeads: numHeadsForLayer(cfg, i)}, + } + if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { + layer.InputNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { + layer.PostAttentionNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if layer.InputNorm == nil || layer.PostAttentionNorm == nil { + return fmt.Errorf("layer %d: missing layer norms", i) + } + + layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") + layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") + layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") + layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") + layer.Attention.GProj = linears.Make(layerPrefix + ".self_attn.g_proj") + if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { + layer.Attention.QNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil { + layer.Attention.KNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil || layer.Attention.GProj == nil { + return fmt.Errorf("layer %d: missing attention projections", i) + } + if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil { + return fmt.Errorf("layer %d: 
missing attention q/k norms", i) + } + + if layerUsesMoE(cfg, i) { + moe := &SparseMoE{Gate: linears.Make(layerPrefix + ".mlp.gate")} + if moe.Gate == nil { + return fmt.Errorf("layer %d: missing moe gate", i) + } + moe.EScoreCorrectionBias, _ = tensorAny(tensors, + layerPrefix+".mlp.experts.e_score_correction_bias", + layerPrefix+".mlp.switch_mlp.e_score_correction_bias", + ) + if moe.EScoreCorrectionBias != nil && moe.EScoreCorrectionBias.DType() != mlx.DTypeFloat32 { + bias := moe.EScoreCorrectionBias.AsType(mlx.DTypeFloat32).Clone() + mlx.Eval(bias) + moe.EScoreCorrectionBias = bias + } + + gateW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.gate_proj", + layerPrefix+".mlp.experts.gate_proj", + ) + upW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.up_proj", + layerPrefix+".mlp.experts.up_proj", + ) + downW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.down_proj", + layerPrefix+".mlp.experts.down_proj", + ) + if gateW == nil || upW == nil || downW == nil { + gateW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "gate_proj", cfg.NumExperts) + upW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "up_proj", cfg.NumExperts) + downW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "down_proj", cfg.NumExperts) + } + if gateW == nil || upW == nil || downW == nil { + return fmt.Errorf("layer %d: missing moe expert weights", i) + } + sw := &SwitchMLP{} + if gateW.Scales != nil && upW.Scales != nil && downW.Scales != nil { + sw.UseQuantized = true + sw.DownWeightQ, sw.DownScales, sw.DownBiases = downW.Weight, downW.Scales, downW.Biases + sw.DownGlobalScale = downW.GlobalScales + sw.DownBits, sw.DownGroupSize, sw.DownMode = downW.Bits, downW.GroupSize, downW.Mode + if canFuseQuantizedGateUp(gateW, upW) { + sw.UseFusedGateUp = true + sw.GateUpWeightQ = 
// Forward runs one gated self-attention block over x and returns the output of
// the output projection.
//
// x is projected to q/k/v, reshaped to (B, L, heads, HeadDim), q and k are
// RMS-normalised, and all three are transposed to (B, heads, L, HeadDim)
// before RoPE. a.NumHeads is used for q, so the query head count can differ
// per layer; k and v always use cfg.NumKeyValueHeads. When c is non-nil it is
// updated with the new k/v and its history feeds attention; otherwise the
// fresh k/v are used directly.
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, layer *Layer, cfg *Config) *mlx.Array {
	numHeads := a.NumHeads
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	q = mlx.Reshape(q, B, L, numHeads, cfg.HeadDim)
	k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
	v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)

	// Q/K norms are applied per head, before the heads axis is moved forward.
	q = a.QNorm.Forward(q, cfg.RMSNormEps)
	k = a.KNorm.Forward(k, cfg.RMSNormEps)

	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Full-attention layers use the "full" RoPE settings, including the
	// precomputed frequency table (set for YaRN in parseConfig). Sliding
	// layers use their own dim/base/scale and always pass nil frequencies.
	ropeDim, ropeBase, ropeMSScale, ropeFreqs := cfg.FullRopeDim, cfg.FullRopeBase, cfg.FullRopeScale, cfg.FullRopeFreqs
	if layer.IsSliding {
		ropeDim, ropeBase, ropeMSScale, ropeFreqs = cfg.SlidingRopeDim, cfg.SlidingRopeBase, cfg.SlidingRopeScale, nil
	}
	// scaleRotaryPart multiplies only the first ropeDim features by
	// ropeMSScale and is a no-op when the scale is 1.
	q = scaleRotaryPart(mlx.RoPEWithFreqs(q, ropeDim, false, ropeBase, 1.0, positions, ropeFreqs), ropeDim, ropeMSScale)
	k = scaleRotaryPart(mlx.RoPEWithFreqs(k, ropeDim, false, ropeBase, 1.0, positions, ropeFreqs), ropeDim, ropeMSScale)

	var kv nn.SDPAOption
	if c != nil {
		// The single-value type assertion panics if the cache does not
		// implement cache.Attention.
		history := c.(cache.Attention).Update(b, k, v)
		kv = nn.WithKVHistory(history)
	} else {
		kv = nn.WithKV(k, v, b.SeqQueryLens)
	}
	// NOTE(review): the mask is CausalMask for sliding layers too; the window
	// limit is presumably enforced by the rotating cache — TODO confirm.
	out := nn.ScaledDotProductAttention(b, q, cfg.Scale, kv, nn.WithMask(nn.CausalMask()))

	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, numHeads, cfg.HeadDim)
	// Output gate: softplus of the g_proj output gets a trailing unit axis and
	// is broadcast-multiplied over HeadDim. This assumes GProj yields one
	// value per head — TODO confirm. cfg.Gating is not consulted here.
	gate := mlx.ExpandDims(mlx.SoftplusF32(a.GProj.Forward(x)), -1)
	out = mlx.Reshape(mlx.Mul(out, gate), B, L, numHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// Forward applies the routed experts selected by indices to x and returns the
// per-expert outputs reshaped to (B, L, topK, HiddenSize), where B and L are
// the first two dims of x and topK is cfg.NumExpertsPerTok. Weighting by
// router scores and summing over topK is left to the caller.
//
// Four weight layouts are handled: quantized or dense, each with either a
// fused gate+up matrix or separate gate and up matrices.
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	topK := cfg.NumExpertsPerTok

	// One row per token with two unit axes: (B*L, 1, 1, HiddenSize).
	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
	idxFlat := mlx.Reshape(indices, B*L, topK)
	// With 64 or more tokens the (token, expert) pairs are sorted by expert id
	// before the gather matmuls.
	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		// Argsort of the permutation gives its inverse, used at the end to
		// restore the original pair order.
		invOrder = mlx.Argsort(order, 0)
		// Flattened pair p belongs to token p/topK, so gather one input row
		// per sorted pair: xFlat becomes (n, 1, 1, HiddenSize).
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array
	if s.UseQuantized {
		if s.UseFusedGateUp {
			// One matmul yields gate and up concatenated on the last axis;
			// split it in half.
			gateUp := mlx.GatherQMM(xFlat, s.GateUpWeightQ, s.GateUpScales, s.GateUpBiases, nil, idxFlat, true, s.GateUpGroupSize, s.GateUpBits, s.GateUpMode, doSort)
			guDims := gateUp.Dims()
			mid := int32(guDims[len(guDims)-1] / 2)
			gate = mlx.SliceStartStop(gateUp, []int32{0, 0, 0, 0}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), mid})
			up = mlx.SliceStartStop(gateUp, []int32{0, 0, 0, mid}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
			hidden = mlx.SwiGLU(gate, up)
		} else {
			gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, nil, idxFlat, true, s.GateGroupSize, s.GateBits, s.GateMode, doSort)
			// NOTE(review): Take(scale, idxFlat, 0) must broadcast against the
			// 4-D matmul result; that depends on the stacked global-scale
			// shape, which is not visible here — verify. The same applies to
			// the up and down scales below.
			if s.GateGlobalScale != nil {
				gate = mlx.Mul(gate, mlx.Take(s.GateGlobalScale, idxFlat, 0))
			}
			up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, nil, idxFlat, true, s.UpGroupSize, s.UpBits, s.UpMode, doSort)
			if s.UpGlobalScale != nil {
				up = mlx.Mul(up, mlx.Take(s.UpGlobalScale, idxFlat, 0))
			}
			hidden = mlx.SwiGLU(gate, up)
		}
		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, nil, idxFlat, true, s.DownGroupSize, s.DownBits, s.DownMode, doSort)
		if s.DownGlobalScale != nil {
			down = mlx.Mul(down, mlx.Take(s.DownGlobalScale, idxFlat, 0))
		}
	} else {
		// Dense path: same structure using unquantized gather matmuls.
		if s.UseFusedGateUp && s.GateUpWeight != nil {
			gateUp := mlx.GatherMM(xFlat, s.GateUpWeight, nil, idxFlat, doSort)
			guDims := gateUp.Dims()
			mid := int32(guDims[len(guDims)-1] / 2)
			gate = mlx.SliceStartStop(gateUp, []int32{0, 0, 0, 0}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), mid})
			up = mlx.SliceStartStop(gateUp, []int32{0, 0, 0, mid}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
			hidden = mlx.SwiGLU(gate, up)
		} else {
			gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
			up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
			hidden = mlx.SwiGLU(gate, up)
		}
		down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
	}
	if doSort {
		// Drop the unit axes, undo the expert sort, and regroup pairs per token.
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}
	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
mlx.TakeAlongAxis(probs, inds, -1) + if cfg.NormTopKProb && cfg.NumExpertsPerTok > 1 { + scores = mlx.Div(scores, mlx.Sum(scores, -1, true)) + } + return scores, inds +} + +func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + BL := B * L + + shared := m.SharedExpert.Forward(x, cfg) + xFlat := mlx.Reshape(x, BL, cfg.HiddenSize) + scores, inds := m.route(xFlat, cfg) + scores = scores.AsType(x.DType()) + + expertOut := m.SwitchMLP.Forward(x, inds, cfg) + routed := mlx.Sum(mlx.Mul(expertOut, mlx.ExpandDims(mlx.Reshape(scores, B, L, cfg.NumExpertsPerTok), -1)), 2, false) + if cfg.MoeRoutedScalingFactor != 1 { + routed = mlx.MulScalar(routed, cfg.MoeRoutedScalingFactor) + } + return mlx.Add(routed, shared) +} + +func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array { + r := l.Attention.Forward(l.InputNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, l, cfg) + h := mlx.Add(x, r) + r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg) + return mlx.Add(h, r) +} + +func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() + B, L := int32(dims[0]), int32(dims[1]) + positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets)) + h := m.EmbedTokens.Forward(b.InputIDs) + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil && i < len(caches) { + c = caches[i] + } + h = layer.Forward(h, b, c, positions, B, L, m.Config) + } + return m.Norm.Forward(h, m.RMSNormEps) +} + +func (m *Model) Unembed(x *mlx.Array) *mlx.Array { + return m.LMHead.Forward(x) +} + +func (m *Model) NumLayers() int { + return len(m.Layers) +} + +func (m *Model) MaxContextLength() int { + return int(m.MaxPositionEmbeddings) +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +func (m *Model) NewCaches() []cache.Cache { + caches := make([]cache.Cache, 
len(m.Layers)) + for i, layer := range m.Layers { + if m.SlidingWindow > 0 && layer.IsSliding { + caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} diff --git a/x/models/laguna/laguna_test.go b/x/models/laguna/laguna_test.go new file mode 100644 index 000000000..ad81d1ebd --- /dev/null +++ b/x/models/laguna/laguna_test.go @@ -0,0 +1,509 @@ +package laguna + +import ( + "math" + "testing" + + "github.com/ollama/ollama/x/mlxrunner/batch" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/models/nn" +) + +func TestParseConfigLagunaXS(t *testing.T) { + skipIfNoMLX(t) + cfg, err := parseConfig([]byte(`{ + "model_type": "laguna", + "hidden_size": 2048, + "intermediate_size": 8192, + "moe_intermediate_size": 512, + "shared_expert_intermediate_size": 512, + "num_hidden_layers": 4, + "num_attention_heads": 48, + "num_attention_heads_per_layer": [48, 64, 64, 64], + "num_key_value_heads": 8, + "head_dim": 128, + "vocab_size": 100352, + "max_position_embeddings": 131072, + "layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"], + "sliding_window": 512, + "mlp_only_layers": [0], + "decoder_sparse_step": 1, + "num_experts": 256, + "num_experts_per_tok": 8, + "norm_topk_prob": true, + "moe_routed_scaling_factor": 2.5, + "gating": "per-head", + "rms_norm_eps": 1e-6, + "partial_rotary_factor": 0.5, + "rope_parameters": { + "rope_theta": 500000, + "rope_type": "yarn", + "factor": 32, + "original_max_position_embeddings": 4096, + "beta_fast": 64, + "beta_slow": 1, + "attention_factor": 1 + }, + "swa_rope_parameters": { + "partial_rotary_factor": 1.0, + "rope_theta": 10000, + "rope_type": "linear" + } + }`)) + if err != nil { + t.Fatal(err) + } + + if cfg.FullRopeDim != 64 { + t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim) + } + if cfg.FullRopeBase != 500000 { + t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase) + } + if 
// TestParseConfigLagunaFP8RopeScaling checks a config that carries its RoPE
// settings in the flat layout: top-level rope_theta and partial_rotary_factor
// plus a rope_scaling object, with no rope_parameters key. It asserts that the
// full-attention RoPE base comes from rope_theta and that the rotary dim is
// head_dim * partial_rotary_factor (128 * 0.5 = 64).
// NOTE(review): presumably this is the layout shipped with the FP8
// checkpoints, per the test name — confirm.
func TestParseConfigLagunaFP8RopeScaling(t *testing.T) {
	skipIfNoMLX(t)
	cfg, err := parseConfig([]byte(`{
		"hidden_size": 2048,
		"intermediate_size": 8192,
		"num_hidden_layers": 1,
		"num_attention_heads": 48,
		"num_key_value_heads": 8,
		"head_dim": 128,
		"vocab_size": 100352,
		"max_position_embeddings": 131072,
		"rope_theta": 500000,
		"partial_rotary_factor": 0.5,
		"rope_scaling": {
			"rope_type": "yarn",
			"factor": 32
		}
	}`))
	if err != nil {
		t.Fatal(err)
	}
	if cfg.FullRopeBase != 500000 {
		t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
	}
	if cfg.FullRopeDim != 64 {
		t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
	}
}
"sliding_attention", "sliding_attention"], + "sliding_window": 512, + "mlp_layer_types": ["dense", "sparse", "sparse", "sparse"], + "num_experts": 256, + "num_experts_per_tok": 8, + "moe_routed_scaling_factor": 2.5, + "gating": true, + "rms_norm_eps": 1e-6, + "partial_rotary_factor": 0.5, + "rope_parameters": { + "full_attention": { + "rope_theta": 500000, + "rope_type": "yarn", + "factor": 32, + "original_max_position_embeddings": 4096, + "beta_fast": 64, + "beta_slow": 1, + "attention_factor": 1, + "partial_rotary_factor": 0.5 + }, + "sliding_attention": { + "rope_theta": 10000, + "rope_type": "default", + "partial_rotary_factor": 1.0 + } + } + }`)) + if err != nil { + t.Fatal(err) + } + + if cfg.Gating != "per-head" { + t.Fatalf("Gating = %q, want per-head", cfg.Gating) + } + if !cfg.NormTopKProb { + t.Fatal("NormTopKProb should default true") + } + if cfg.FullRopeBase != 500000 { + t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase) + } + if cfg.SlidingRopeBase != 10000 { + t.Fatalf("SlidingRopeBase = %v, want 10000", cfg.SlidingRopeBase) + } + if cfg.FullRopeDim != 64 { + t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim) + } + if cfg.SlidingRopeDim != 128 { + t.Fatalf("SlidingRopeDim = %d, want 128", cfg.SlidingRopeDim) + } + if layerUsesMoE(&cfg, 0) { + t.Fatal("layer 0 should be dense due to mlp_layer_types") + } + if !layerUsesMoE(&cfg, 1) { + t.Fatal("layer 1 should use MoE") + } +} + +func TestTinyLagunaLoadAndForward(t *testing.T) { + skipIfNoMLX(t) + cfg, err := parseConfig([]byte(`{ + "model_type": "laguna", + "hidden_size": 8, + "intermediate_size": 12, + "moe_intermediate_size": 4, + "shared_expert_intermediate_size": 4, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_attention_heads_per_layer": [2, 2], + "num_key_value_heads": 1, + "head_dim": 4, + "vocab_size": 16, + "max_position_embeddings": 64, + "layer_types": ["full_attention", "sliding_attention"], + "sliding_window": 2, + "mlp_only_layers": [0], + 
"decoder_sparse_step": 1, + "num_experts": 2, + "num_experts_per_tok": 1, + "norm_topk_prob": false, + "moe_routed_scaling_factor": 2.5, + "gating": "per-head", + "rms_norm_eps": 1e-5, + "partial_rotary_factor": 0.5, + "rope_parameters": { + "rope_theta": 10000, + "rope_type": "yarn", + "factor": 2, + "original_max_position_embeddings": 16, + "beta_fast": 32, + "beta_slow": 1 + }, + "swa_rope_parameters": { + "partial_rotary_factor": 1.0, + "rope_theta": 10000, + "rope_type": "linear" + } + }`)) + if err != nil { + t.Fatal(err) + } + + m := &Model{ + Config: &cfg, + Layers: []*Layer{ + {LayerIdx: 0, IsSliding: false}, + {LayerIdx: 1, IsSliding: true}, + }, + } + tensors := tinyLagunaTensors() + if err := m.LoadWeights(tensors); err != nil { + t.Fatalf("LoadWeights failed: %v", err) + } + + tokens := mlx.FromValues([]int32{1, 2, 3}, 1, 3) + caches := m.NewCaches() + defer func() { + for _, c := range caches { + if c != nil { + c.Free() + } + } + }() + hidden := m.Forward(&batch.Batch{ + InputIDs: tokens, + SeqOffsets: []int32{0}, + SeqQueryLens: []int32{int32(tokens.Dim(1))}, + }, caches) + mlx.Eval(hidden) + if got := hidden.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 8 { + t.Fatalf("hidden shape = %v, want [1 3 8]", got) + } + + logits := m.Unembed(hidden) + mlx.Eval(logits) + if got := logits.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 16 { + t.Fatalf("logits shape = %v, want [1 3 16]", got) + } + for i, v := range logits.Floats() { + if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) { + t.Fatalf("logits[%d] is not finite: %v", i, v) + } + } +} + +func TestTinyLagunaLoadWeightsFusesDenseGateUp(t *testing.T) { + skipIfNoMLX(t) + cfg, err := parseConfig([]byte(`{ + "model_type": "laguna", + "hidden_size": 8, + "intermediate_size": 12, + "moe_intermediate_size": 4, + "shared_expert_intermediate_size": 4, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_attention_heads_per_layer": [2, 2], + 
"num_key_value_heads": 1, + "head_dim": 4, + "vocab_size": 16, + "max_position_embeddings": 64, + "layer_types": ["full_attention", "sliding_attention"], + "sliding_window": 2, + "mlp_only_layers": [0], + "decoder_sparse_step": 1, + "num_experts": 2, + "num_experts_per_tok": 1, + "norm_topk_prob": false, + "moe_routed_scaling_factor": 2.5, + "gating": "per-head", + "rms_norm_eps": 1e-5 + }`)) + if err != nil { + t.Fatal(err) + } + + m := &Model{ + Config: &cfg, + Layers: []*Layer{ + {LayerIdx: 0, IsSliding: false}, + {LayerIdx: 1, IsSliding: true}, + }, + } + if err := m.LoadWeights(tinyLagunaTensors()); err != nil { + t.Fatalf("LoadWeights failed: %v", err) + } + + moe, ok := m.Layers[1].MLP.(*SparseMoE) + if !ok { + t.Fatalf("layer 1 MLP type = %T, want *SparseMoE", m.Layers[1].MLP) + } + if !moe.SwitchMLP.UseFusedGateUp { + t.Fatal("expected dense SwitchMLP to fuse gate/up expert weights") + } + if moe.SwitchMLP.GateUpWeight == nil { + t.Fatal("expected fused GateUpWeight to be populated") + } + if got, want := moe.SwitchMLP.GateUpWeight.Dims(), []int{2, 8, 8}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] || got[2] != want[2] { + t.Fatalf("GateUpWeight dims = %v, want %v", got, want) + } +} + +func TestSparseMoERouteBiasAffectsSelectionNotRoutingWeights(t *testing.T) { + skipIfNoMLX(t) + cfg := &Config{ + HiddenSize: 1, + NumExperts: 2, + NumExpertsPerTok: 1, + NormTopKProb: false, + } + + moe := &SparseMoE{ + Gate: nn.NewLinear(mlx.FromValues([]float32{-4, -3}, 2, 1).AsType(mlx.DTypeBFloat16), nil), + EScoreCorrectionBias: mlx.FromValues([]float32{0.5, 0}, 2), + } + + xFlat := mlx.FromValues([]float32{1}, 1, int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16) + scores, inds := moe.route(xFlat, cfg) + scores = scores.AsType(mlx.DTypeFloat32) + inds = inds.AsType(mlx.DTypeInt32) + mlx.Eval(scores, inds) + + gates := moe.Gate.Forward(xFlat).AsType(mlx.DTypeFloat32) + probs := mlx.Sigmoid(gates) + mlx.Eval(probs) + probVals := probs.Floats() + if 
probVals[0] >= probVals[1] { + t.Fatalf("expected unbiased sigmoid scores to prefer expert 1, got %v", probVals) + } + if probVals[0]+0.5 <= probVals[1] { + t.Fatalf("expected bias to flip selection to expert 0, got probs=%v", probVals) + } + if got := inds.Ints(); len(got) != 1 || got[0] != 0 { + t.Fatalf("selected experts = %v, want [0]", got) + } + if got := scores.Floats(); len(got) != 1 || math.Abs(float64(got[0]-probVals[0])) > 1e-6 { + t.Fatalf("routing weights = %v, want [%v] using unbiased sigmoid scores", got, probVals[0]) + } +} + +func TestSwitchMLPFusedGateUpMatchesSeparate(t *testing.T) { + skipIfNoMLX(t) + cfg := &Config{HiddenSize: 4, NumExpertsPerTok: 2} + B, L := int32(2), int32(3) + xVals := make([]float32, int(B*L*cfg.HiddenSize)) + for i := range xVals { + xVals[i] = float32((i%17)-8) * 0.01 + } + x := mlx.FromValues(xVals, int(B), int(L), int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16) + + indicesVals := make([]int32, B*L*cfg.NumExpertsPerTok) + for i := 0; i < len(indicesVals); i += int(cfg.NumExpertsPerTok) { + indicesVals[i] = int32((i / int(cfg.NumExpertsPerTok)) % 2) + indicesVals[i+1] = int32(((i / int(cfg.NumExpertsPerTok)) + 1) % 2) + } + indices := mlx.FromValues(indicesVals, int(B*L), int(cfg.NumExpertsPerTok)) + + separate := &SwitchMLP{ + GateWeight: makePatternExpertWeight(2, 4, 3, 0.011), + UpWeight: makePatternExpertWeight(2, 4, 3, 0.017), + DownWeight: makePatternExpertWeight(2, 3, 4, 0.013), + } + fused := &SwitchMLP{ + GateUpWeight: fuseExpertStacks(separate.GateWeight, separate.UpWeight, 2), + DownWeight: separate.DownWeight, + UseFusedGateUp: true, + } + + gotSeparate := separate.Forward(x, indices, cfg) + gotFused := fused.Forward(x, indices, cfg) + mlx.Eval(gotSeparate, gotFused) + + gotFusedF32 := gotFused.AsType(mlx.DTypeFloat32) + gotSeparateF32 := gotSeparate.AsType(mlx.DTypeFloat32) + mlx.Eval(gotFusedF32, gotSeparateF32) + assertFloatSlicesClose(t, gotFusedF32.Floats(), gotSeparateF32.Floats(), 1e-5) +} + +func 
TestCombinedTensorGlobalScaleIgnoresInputGlobalScale(t *testing.T) { + skipIfNoMLX(t) + tensors := map[string]*mlx.Array{ + "proj.weight.global_scale": mlx.FromValues([]float32{0.25}, 1), + "proj.weight.input_global_scale": mlx.FromValues([]float32{8}, 1), + } + + got, _ := combinedTensorGlobalScale(tensors, "proj.weight") + if got == nil { + t.Fatal("combinedTensorGlobalScale returned nil") + } + mlx.Eval(got) + vals := got.Floats() + if len(vals) != 1 || vals[0] != 0.25 { + t.Fatalf("combinedTensorGlobalScale = %v, want [0.25]", vals) + } +} + +func tinyLagunaTensors() map[string]*mlx.Array { + tensors := map[string]*mlx.Array{ + "model.embed_tokens.weight": weights(16, 8), + "model.norm.weight": ones(8), + "lm_head.weight": weights(16, 8), + } + for layer := range 2 { + prefix := "model.layers." + string(rune('0'+layer)) + tensors[prefix+".input_layernorm.weight"] = ones(8) + tensors[prefix+".post_attention_layernorm.weight"] = ones(8) + tensors[prefix+".self_attn.q_proj.weight"] = weights(8, 8) + tensors[prefix+".self_attn.k_proj.weight"] = weights(4, 8) + tensors[prefix+".self_attn.v_proj.weight"] = weights(4, 8) + tensors[prefix+".self_attn.o_proj.weight"] = weights(8, 8) + tensors[prefix+".self_attn.g_proj.weight"] = weights(2, 8) + tensors[prefix+".self_attn.q_norm.weight"] = ones(4) + tensors[prefix+".self_attn.k_norm.weight"] = ones(4) + } + + tensors["model.layers.0.mlp.gate_proj.weight"] = weights(12, 8) + tensors["model.layers.0.mlp.up_proj.weight"] = weights(12, 8) + tensors["model.layers.0.mlp.down_proj.weight"] = weights(8, 12) + + tensors["model.layers.1.mlp.gate.weight"] = weights(2, 8) + tensors["model.layers.1.mlp.experts.e_score_correction_bias"] = mlx.FromValues([]float32{0.1, -0.1}, 2) + for expert := range 2 { + prefix := "model.layers.1.mlp.experts." 
+ string(rune('0'+expert)) + tensors[prefix+".gate_proj.weight"] = weights(4, 8) + tensors[prefix+".up_proj.weight"] = weights(4, 8) + tensors[prefix+".down_proj.weight"] = weights(8, 4) + } + tensors["model.layers.1.mlp.shared_expert.gate_proj.weight"] = weights(4, 8) + tensors["model.layers.1.mlp.shared_expert.up_proj.weight"] = weights(4, 8) + tensors["model.layers.1.mlp.shared_expert.down_proj.weight"] = weights(8, 4) + return tensors +} + +func makeExpertWeight(vals []float32, dims ...int) *mlx.Array { + return mlx.FromValues(vals, dims...).AsType(mlx.DTypeBFloat16) +} + +func makePatternExpertWeight(numExperts, rows, cols int, scale float32) *mlx.Array { + vals := make([]float32, numExperts*rows*cols) + for i := range vals { + vals[i] = float32((i%23)-11) * scale + } + return makeExpertWeight(vals, numExperts, rows, cols) +} + +func assertFloatSlicesClose(t *testing.T, got, want []float32, tol float64) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range got { + if math.Abs(float64(got[i]-want[i])) > tol { + t.Fatalf("value[%d] = %v, want %v (tol=%g)", i, got[i], want[i], tol) + } + } +} + +func weights(rows, cols int) *mlx.Array { + vals := make([]float32, rows*cols) + for i := range vals { + vals[i] = float32((i%7)-3) * 0.01 + } + return mlx.FromValues(vals, rows, cols) +} + +func ones(n int) *mlx.Array { + vals := make([]float32, n) + for i := range vals { + vals[i] = 1 + } + return mlx.FromValues(vals, n) +} + +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go index efd086628..d7e45edc5 100644 --- a/x/tokenizer/tokenizer_load.go +++ b/x/tokenizer/tokenizer_load.go @@ -447,7 +447,9 @@ func extractPretokenizer(data json.RawMessage) string { if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" { for _, 
pt := range seq.Pretokenizers { if pt.Type == "Split" && pt.Pattern.Regex != "" { - return pt.Pattern.Regex + if _, err := regexp.Compile(rewritePatternForRE2(pt.Pattern.Regex)); err == nil { + return pt.Pattern.Regex + } } } } diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go index caf2b0d35..54e5023c1 100644 --- a/x/tokenizer/tokenizer_load_test.go +++ b/x/tokenizer/tokenizer_load_test.go @@ -22,3 +22,31 @@ func TestLoadFromBytesRejectsWordPiece(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestExtractPretokenizerSkipsUnsupportedSequenceSplit(t *testing.T) { + data := []byte(`{ + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?:\\r?\\n)+(?!\\r?\\n)" + } + }, + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + } + } + ] + }`) + + pattern := extractPretokenizer(data) + if pattern == "" { + t.Fatal("expected supported Split pretokenizer") + } + if strings.Contains(pattern, `(?!\r?\n)`) { + t.Fatalf("selected unsupported newline splitter: %q", pattern) + } +}