New models (#15861)

* mlx: add laguna model support

* convert: support fp8 safetensors import

Decode HF F8_E4M3 safetensors with block scale companions into GGUF-supported tensor types, and record which output tensors came from FP8 source weights.

Use that source-precision metadata during create quantization: default FP8-sourced GGUFs to Q8_0, keep non-FP8 tensors at their original precision for Q8_0, and promote non-FP8 quantizable tensors to Q8_0 for Q4_K requests.

* ggml: add laguna model support

* server: preserve generate logprobs with builtin parsers

Generate requests were dropping logprob-only chunks whenever a builtin parser buffered visible content. Chat already handled this case, but generate only forwarded chunks with visible response, thinking, or tool-call output.

Keep generate chunks that carry logprobs even when the builtin parser has not flushed visible content yet, and add a regression test that exercises the behavior with a generic thinking parser.

* review comments - perf improvements

* ggml: implement nemotron 3 nano omni

* add poolside integration

* update poolside doc

* adapt to new cache setup

* fix test

* fix test

---------

Co-authored-by: Eva Ho <hoyyeva@gmail.com>
This commit is contained in:
Daniel Hiltgen
2026-04-28 11:50:12 -07:00
committed by GitHub
parent 2bbe2405fe
commit 87288ced4f
62 changed files with 11284 additions and 633 deletions

View File

@@ -75,8 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes", "pool"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
r, ok := integrations[name]
@@ -1494,6 +1493,11 @@ func TestIntegration_InstallHint(t *testing.T) {
input: "openclaw",
wantURL: "https://docs.openclaw.ai",
},
{
name: "pool has hint",
input: "pool",
wantURL: "https://github.com/poolsideai/pool",
},
{
name: "unknown has no hint",
input: "unknown",
@@ -1549,7 +1553,19 @@ func TestListIntegrationInfos(t *testing.T) {
for _, info := range infos {
got = append(got, info.Name)
}
if diff := compareStrings(got, integrationOrder); diff != "" {
want := append([]string(nil), integrationOrder...)
if poolsideGOOS == "windows" {
filtered := make([]string, 0, len(want))
for _, name := range want {
if name != "pool" {
filtered = append(filtered, name)
}
}
want = filtered
}
if diff := compareStrings(got, want); diff != "" {
t.Fatalf("launcher integration order mismatch: %s", diff)
}
})
@@ -1567,6 +1583,9 @@ func TestListIntegrationInfos(t *testing.T) {
t.Run("includes known integrations", func(t *testing.T) {
known := map[string]bool{"claude": false, "codex": false, "opencode": false}
if poolsideGOOS != "windows" {
known["pool"] = false
}
for _, info := range infos {
if _, ok := known[info.Name]; ok {
known[info.Name] = true
@@ -1601,6 +1620,17 @@ func TestListIntegrationInfos(t *testing.T) {
}
})
}
// TestListIntegrationInfos_HidesPoolsideOnWindows forces the Poolside OS
// override to "windows" and checks that no listed integration is named "pool".
func TestListIntegrationInfos_HidesPoolsideOnWindows(t *testing.T) {
	saved := poolsideGOOS
	t.Cleanup(func() { poolsideGOOS = saved })
	poolsideGOOS = "windows"

	for _, integration := range ListIntegrationInfos() {
		if integration.Name != "pool" {
			continue
		}
		t.Fatal("expected pool to be hidden on Windows")
	}
}
func TestBuildModelList_Descriptions(t *testing.T) {
t.Run("installed recommended has base description", func(t *testing.T) {
@@ -1707,6 +1737,20 @@ func TestIntegration_AutoInstallable(t *testing.T) {
}
}
// TestEnsureIntegrationInstalled_PoolsideUnsupportedOnWindows checks that
// EnsureIntegrationInstalled rejects "pool" with the Windows-unsupported
// message when the Poolside OS override is "windows".
func TestEnsureIntegrationInstalled_PoolsideUnsupportedOnWindows(t *testing.T) {
	saved := poolsideGOOS
	t.Cleanup(func() { poolsideGOOS = saved })
	poolsideGOOS = "windows"

	err := EnsureIntegrationInstalled("pool", &Poolside{})
	switch {
	case err == nil:
		t.Fatal("expected Windows unsupported error")
	case !strings.Contains(err.Error(), "not currently supported on Windows"):
		t.Fatalf("expected Windows warning, got %v", err)
	}
}
func TestIntegrationModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -213,6 +213,7 @@ Supported integrations:
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
pool Poolside
vscode    VS Code (aliases: code)
Examples:

51
cmd/launch/poolside.go Normal file
View File

@@ -0,0 +1,51 @@
package launch
import (
"fmt"
"os"
"os/exec"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Poolside implements Runner for Poolside's `pool` command-line tool.
type Poolside struct{}

// poolsideGOOS is the operating system name consulted by the Poolside
// platform checks. It is initialised from runtime.GOOS and kept in a variable
// so tests can substitute another value.
var poolsideGOOS = runtime.GOOS

// String returns the display name of the integration.
func (p *Poolside) String() string {
	return "Poolside"
}

// poolsideUnsupportedError builds the error returned when Poolside is
// requested while poolsideGOOS is "windows".
func poolsideUnsupportedError() error {
	return fmt.Errorf("Warning: Poolside is not currently supported on Windows")
}

// args assembles the argument list for the pool binary: an optional
// "-m <model>" pair (omitted when model is empty) followed by the extra
// arguments in their original order.
func (p *Poolside) args(model string, extra []string) []string {
	var out []string
	if model != "" {
		out = append(out, "-m", model)
	}
	return append(out, extra...)
}
// Run starts the pool binary found on PATH with the arguments produced by
// p.args, wiring it to this process's stdin, stdout and stderr. The child
// environment is the current environment plus POOLSIDE_STANDALONE_BASE_URL
// (the configured Ollama host with "/v1" appended) and POOLSIDE_API_KEY.
//
// It returns the Windows-unsupported error when poolsideGOOS is "windows",
// a "pool is not installed" error when the binary cannot be located, and
// otherwise whatever the child process run returns.
func (p *Poolside) Run(model string, args []string) error {
	if poolsideGOOS == "windows" {
		return poolsideUnsupportedError()
	}

	path, lookErr := exec.LookPath("pool")
	if lookErr != nil {
		return fmt.Errorf("pool is not installed")
	}

	child := exec.Command(path, p.args(model, args)...)
	child.Stdin, child.Stdout, child.Stderr = os.Stdin, os.Stdout, os.Stderr

	baseURL := envconfig.Host().String() + "/v1"
	child.Env = append(os.Environ(),
		"POOLSIDE_STANDALONE_BASE_URL="+baseURL,
		"POOLSIDE_API_KEY=ollama",
	)
	return child.Run()
}

View File

@@ -0,0 +1,88 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
// TestPoolsideArgs is a table-driven check of the argument list built by
// Poolside.args for combinations of model name and extra arguments.
func TestPoolsideArgs(t *testing.T) {
	type argsCase struct {
		name  string
		model string
		extra []string
		want  []string
	}

	cases := []argsCase{
		{name: "with model", model: "qwen3.5", want: []string{"-m", "qwen3.5"}},
		{name: "without model", extra: []string{"session"}, want: []string{"session"}},
		{name: "with model and extra args", model: "llama3.2", extra: []string{"--foo", "bar"}, want: []string{"-m", "llama3.2", "--foo", "bar"}},
	}

	runner := &Poolside{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := runner.args(tc.model, tc.extra); !slices.Equal(got, tc.want) {
				t.Fatalf("args(%q, %v) = %v, want %v", tc.model, tc.extra, got, tc.want)
			}
		})
	}
}
// TestPoolsideRunSetsOllamaEnv runs Poolside.Run against a fake "pool" shell
// script and checks the environment variables and arguments the script saw.
func TestPoolsideRunSetsOllamaEnv(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "pool.log")
poolPath := filepath.Join(tmpDir, "pool")
// The fake binary records the two Poolside variables and its arguments in
// logPath so they can be inspected after Run returns.
script := "#!/bin/sh\n" +
"printf 'base=%s\\nkey=%s\\nargs=%s\\n' \"$POOLSIDE_STANDALONE_BASE_URL\" \"$POOLSIDE_API_KEY\" \"$*\" > \"" + logPath + "\"\n"
if err := os.WriteFile(poolPath, []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake pool binary: %v", err)
}
// PATH is reduced to the temp dir so the lookup for "pool" finds the fake.
t.Setenv("PATH", tmpDir)
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
p := &Poolside{}
if err := p.Run("qwen3.5", []string{"session"}); err != nil {
t.Fatalf("Run returned error: %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read pool log: %v", err)
}
got := string(data)
if !strings.Contains(got, "base=http://127.0.0.1:11434/v1") {
t.Fatalf("expected Poolside base URL override in log, got:\n%s", got)
}
if !strings.Contains(got, "key=ollama") {
t.Fatalf("expected Poolside API key override in log, got:\n%s", got)
}
if !strings.Contains(got, "args=-m qwen3.5 session") {
t.Fatalf("expected model and extra args in log, got:\n%s", got)
}
}
// TestPoolsideRunWindowsUnsupported checks that Run fails with the
// Windows-unsupported message when the OS override is "windows".
func TestPoolsideRunWindowsUnsupported(t *testing.T) {
	saved := poolsideGOOS
	t.Cleanup(func() { poolsideGOOS = saved })
	poolsideGOOS = "windows"

	runner := &Poolside{}
	err := runner.Run("kimi-k2.6:cloud", nil)
	switch {
	case err == nil:
		t.Fatal("expected Windows unsupported error")
	case !strings.Contains(err.Error(), "not currently supported on Windows"):
		t.Fatalf("expected Windows warning, got %v", err)
	}
}

View File

@@ -33,7 +33,7 @@ type IntegrationInfo struct {
Description string
}
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi", "pool"}
var integrationSpecs = []*IntegrationSpec{
{
@@ -166,6 +166,18 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
},
},
{
Name: "pool",
Runner: &Poolside{},
Description: "Poolside's software agent for enterprise development",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("pool")
return err == nil
},
URL: "https://github.com/poolsideai/pool",
},
},
{
Name: "hermes",
Runner: &Hermes{},
@@ -285,6 +297,9 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
if spec.Hidden {
continue
}
if spec.Name == "pool" && poolsideGOOS == "windows" {
continue
}
visible = append(visible, *spec)
}
@@ -399,6 +414,10 @@ func EnsureIntegrationInstalled(name string, runner Runner) error {
return fmt.Errorf("%s is not installed", runner)
}
if integration.spec.Name == "pool" && poolsideGOOS == "windows" {
return poolsideUnsupportedError()
}
if integration.installed {
return nil
}

View File

@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
{
name: "pool",
binary: "pool",
runner: &Poolside{},
checkPath: func(home string) string {
return filepath.Join(home, ".poolside", "config")
},
},
{
name: "kimi",
binary: "kimi",
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == "pool" && poolsideGOOS == "windows" {
t.Skip("Poolside is intentionally unsupported on Windows")
}
home := t.TempDir()
setTestHome(t, home)

View File

@@ -316,6 +316,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "LagunaForCausalLM":
conv = &lagunaModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
@@ -324,6 +326,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
conv = &qwen3NextModel{}
case "NemotronH_Nano_VL_V2", "NemotronH_Nano_Omni_Reasoning_V3":
conv = &nemotronHNanoVLModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}
default:
@@ -387,6 +391,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
for k, v := range sourceTensorKV(ts) {
kv[k] = v
}
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

604
convert/convert_laguna.go Normal file
View File

@@ -0,0 +1,604 @@
package convert
import (
"cmp"
"encoding/json"
"fmt"
iofs "io/fs"
"math"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
// lagunaModel is the normalized Laguna configuration used by the converter.
// Its UnmarshalJSON decodes the config into rawLagunaModel first and then
// applies defaults, so the values held here are post-normalization.
type lagunaModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
RMSNormEPS float32 `json:"rms_norm_eps"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
// Gating may be written as a JSON string or boolean in the config.
Gating lagunaGatingMode `json:"gating"`
// QKNormType defaults to "rmsnorm" when the config leaves it empty.
QKNormType string `json:"qk_norm_type"`
LayerTypes []string `json:"layer_types"`
NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"`
NumExperts uint32 `json:"num_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermediateSize uint32 `json:"shared_expert_intermediate_size"`
// NormTopKProb and MoERouterUseSigmoid default to true when absent.
NormTopKProb bool `json:"norm_topk_prob"`
MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"`
MoERouterUseSigmoid bool `json:"moe_router_use_sigmoid"`
MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"`
DecoderSparseStep uint32 `json:"decoder_sparse_step"`
// MLPOnlyLayers is derived from mlp_layer_types when mlp_only_layers is
// not provided.
MLPOnlyLayers []uint32 `json:"mlp_only_layers"`
MLPLayerTypes []string `json:"mlp_layer_types"`
// RopeParameters holds the full-attention rope settings; SwaRopeParameters
// holds the sliding-window ones.
RopeParameters lagunaRopeParameters `json:"rope_parameters"`
SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"`
SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"`
}
type lagunaGatingMode string
type lagunaRopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
RopeType string `json:"rope_type"`
Type string `json:"type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
BetaSlow float32 `json:"beta_slow"`
BetaFast float32 `json:"beta_fast"`
AttentionFactor float32 `json:"attention_factor"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
}
type lagunaRopeConfig struct {
flat lagunaRopeParameters
full lagunaRopeParameters
sliding lagunaRopeParameters
nested bool
}
func (g *lagunaGatingMode) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err == nil {
*g = lagunaGatingMode(s)
return nil
}
var enabled bool
if err := json.Unmarshal(b, &enabled); err == nil {
if enabled {
*g = "true"
} else {
*g = "false"
}
return nil
}
if string(b) == "null" {
return nil
}
return fmt.Errorf("unsupported Laguna gating JSON value %s", string(b))
}
func (g lagunaGatingMode) perHead() bool {
return strings.EqualFold(string(g), "per-head") || strings.EqualFold(string(g), "true")
}
func (r *lagunaRopeConfig) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
return nil
}
var probe map[string]json.RawMessage
if err := json.Unmarshal(b, &probe); err != nil {
return err
}
if len(probe) == 0 {
return nil
}
if raw, ok := probe["full_attention"]; ok {
r.nested = true
if err := json.Unmarshal(raw, &r.full); err != nil {
return err
}
if raw = probe["sliding_attention"]; raw != nil {
if err := json.Unmarshal(raw, &r.sliding); err != nil {
return err
}
}
return nil
}
if raw, ok := probe["global_attention"]; ok {
r.nested = true
if err := json.Unmarshal(raw, &r.full); err != nil {
return err
}
if raw = probe["sliding_attention"]; raw != nil {
if err := json.Unmarshal(raw, &r.sliding); err != nil {
return err
}
}
return nil
}
return json.Unmarshal(b, &r.flat)
}
func (r lagunaRopeConfig) fullParams() lagunaRopeParameters {
if r.nested {
return r.full
}
return r.flat
}
func (r lagunaRopeConfig) slidingParams() (lagunaRopeParameters, bool) {
if !r.nested {
return lagunaRopeParameters{}, false
}
return r.sliding, true
}
func (r lagunaRopeParameters) ropeType() string {
return cmp.Or(r.RopeType, r.Type)
}
func (r lagunaRopeParameters) withDefaultPartialRotaryFactor(v float32) lagunaRopeParameters {
if r.PartialRotaryFactor == 0 {
r.PartialRotaryFactor = v
}
return r
}
func (r lagunaRopeParameters) empty() bool {
return r == (lagunaRopeParameters{})
}
// rawLagunaModel is the direct decode target for the Laguna config JSON
// before lagunaModel.UnmarshalJSON normalizes it. It differs from lagunaModel
// in that the two defaulted booleans are pointers (so an absent key can be
// told apart from an explicit false) and rope_parameters is decoded through
// lagunaRopeConfig.
type rawLagunaModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
RMSNormEPS float32 `json:"rms_norm_eps"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
Gating lagunaGatingMode `json:"gating"`
QKNormType string `json:"qk_norm_type"`
LayerTypes []string `json:"layer_types"`
NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"`
NumExperts uint32 `json:"num_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermediateSize uint32 `json:"shared_expert_intermediate_size"`
// nil means the key was absent.
NormTopKProb *bool `json:"norm_topk_prob"`
MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"`
// nil means the key was absent.
MoERouterUseSigmoid *bool `json:"moe_router_use_sigmoid"`
MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"`
DecoderSparseStep uint32 `json:"decoder_sparse_step"`
MLPOnlyLayers []uint32 `json:"mlp_only_layers"`
MLPLayerTypes []string `json:"mlp_layer_types"`
RopeParameters lagunaRopeConfig `json:"rope_parameters"`
SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"`
SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"`
}
// UnmarshalJSON decodes the config into rawLagunaModel and then overwrites *p
// with a normalized copy: dense layers are resolved, rope parameters are
// flattened into full and sliding-window sets, and absent values receive
// their defaults. It returns any decode error or dense-layer resolution error.
func (p *lagunaModel) UnmarshalJSON(b []byte) error {
var raw rawLagunaModel
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
mlpOnlyLayers, err := lagunaDenseLayers(raw.MLPOnlyLayers, raw.MLPLayerTypes)
if err != nil {
return err
}
// The full-attention partial rotary factor falls back to the top-level
// partial_rotary_factor, and then to 1.
fullRope := raw.RopeParameters.fullParams().withDefaultPartialRotaryFactor(cmp.Or(raw.PartialRotaryFactor, float32(1)))
// Start from the top-level swa_rope_parameters and replace them with the
// sliding entry nested under rope_parameters when that entry is non-empty.
swaRope := raw.SwaRopeParameters
if nestedSwa, ok := raw.RopeParameters.slidingParams(); ok && !nestedSwa.empty() {
swaRope = nestedSwa
}
// The sliding-window factor inherits the full-attention factor when unset.
swaRope = swaRope.withDefaultPartialRotaryFactor(cmp.Or(fullRope.PartialRotaryFactor, float32(1)))
*p = lagunaModel{
ModelParameters: raw.ModelParameters,
NumHiddenLayers: raw.NumHiddenLayers,
HiddenSize: raw.HiddenSize,
IntermediateSize: raw.IntermediateSize,
NumAttentionHeads: raw.NumAttentionHeads,
NumKeyValueHeads: raw.NumKeyValueHeads,
HeadDim: raw.HeadDim,
RMSNormEPS: raw.RMSNormEPS,
MaxPositionEmbeddings: raw.MaxPositionEmbeddings,
SlidingWindow: raw.SlidingWindow,
PartialRotaryFactor: cmp.Or(raw.PartialRotaryFactor, fullRope.PartialRotaryFactor),
Gating: raw.Gating,
QKNormType: cmp.Or(raw.QKNormType, "rmsnorm"),
LayerTypes: raw.LayerTypes,
NumAttentionHeadsPerLayer: raw.NumAttentionHeadsPerLayer,
NumExperts: raw.NumExperts,
NumExpertsPerTok: raw.NumExpertsPerTok,
MoEIntermediateSize: raw.MoEIntermediateSize,
SharedExpertIntermediateSize: raw.SharedExpertIntermediateSize,
NormTopKProb: defaultBool(raw.NormTopKProb, true),
MoeRoutedScalingFactor: raw.MoeRoutedScalingFactor,
MoERouterUseSigmoid: defaultBool(raw.MoERouterUseSigmoid, true),
MoEApplyRouterWeightOnInput: raw.MoEApplyRouterWeightOnInput,
DecoderSparseStep: raw.DecoderSparseStep,
MLPOnlyLayers: mlpOnlyLayers,
MLPLayerTypes: raw.MLPLayerTypes,
RopeParameters: fullRope,
SwaRopeParameters: swaRope,
SwaAttentionSinkEnabled: raw.SwaAttentionSinkEnabled,
}
return nil
}
// defaultBool dereferences v, or yields fallback when v is nil.
func defaultBool(v *bool, fallback bool) bool {
	if v != nil {
		return *v
	}
	return fallback
}
// Numeric codes the converter writes into the Laguna key/value metadata.
const (
// Router gating function codes returned by lagunaMoeGatingFunc.
lagunaGatingFuncSoftmax uint32 = 1
lagunaGatingFuncSigmoid uint32 = 2
// Per-layer attention codes stored under laguna.attention.layer_types;
// any layer that is not sliding is encoded as global.
lagunaLayerTypeGlobal uint32 = 0
lagunaLayerTypeSliding uint32 = 1
)
// KV builds the key/value metadata for a Laguna model. It starts from the
// shared ModelParameters metadata, then sets the architecture, tokenizer
// overrides, and the "laguna."-prefixed attention, expert and rope entries
// from the normalized config.
func (p *lagunaModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "laguna"
// Laguna's chat template and built-in renderer both emit the leading
// special token explicitly. Auto-prepending BOS here would duplicate it.
kv["tokenizer.ggml.add_bos_token"] = false
kv["tokenizer.ggml.pre"] = "laguna"
// Laguna does not need tokenizer.chat_template at runtime: Ollama create
// sets the Laguna renderer/parser from the architecture, and the renderer
// owns prompt formatting.
delete(kv, "tokenizer.chat_template")
kv["laguna.block_count"] = p.NumHiddenLayers
kv["laguna.context_length"] = p.MaxPositionEmbeddings
kv["laguna.embedding_length"] = p.HiddenSize
kv["laguna.feed_forward_length"] = p.IntermediateSize
// Write the per-layer head counts only when there is exactly one entry per
// layer; otherwise write the single model-wide count.
if len(p.NumAttentionHeadsPerLayer) == int(p.NumHiddenLayers) {
kv["laguna.attention.head_count"] = p.NumAttentionHeadsPerLayer
} else {
kv["laguna.attention.head_count"] = p.NumAttentionHeads
}
kv["laguna.attention.head_count_kv"] = p.NumKeyValueHeads
kv["laguna.attention.key_length"] = p.HeadDim
kv["laguna.attention.value_length"] = p.HeadDim
kv["laguna.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["laguna.attention.sliding_window"] = p.SlidingWindow
kv["laguna.attention.sink_enabled"] = p.SwaAttentionSinkEnabled
// Layer types are emitted twice: as numeric codes and as a boolean
// "is sliding" pattern.
if len(p.LayerTypes) > 0 {
encoded := make([]uint32, len(p.LayerTypes))
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
if lagunaLayerIsSliding(layerType) {
encoded[i] = lagunaLayerTypeSliding
slidingPattern[i] = true
} else {
encoded[i] = lagunaLayerTypeGlobal
}
}
kv["laguna.attention.layer_types"] = encoded
kv["laguna.attention.sliding_window_pattern"] = slidingPattern
}
// gating_type is 1 for per-head gating and 0 otherwise.
if p.Gating.perHead() {
kv["laguna.attention.gating_type"] = uint32(1)
} else {
kv["laguna.attention.gating_type"] = uint32(0)
}
kv["laguna.attention.qk_norm"] = p.QKNormType == "rmsnorm"
kv["laguna.expert_count"] = p.NumExperts
kv["laguna.expert_used_count"] = p.NumExpertsPerTok
kv["laguna.expert_feed_forward_length"] = p.MoEIntermediateSize
kv["laguna.expert_shared_feed_forward_length"] = p.SharedExpertIntermediateSize
kv["laguna.expert_shared_count"] = uint32(1)
kv["laguna.expert_weights_norm"] = p.NormTopKProb
kv["laguna.expert_weights_scale"] = p.MoeRoutedScalingFactor
kv["laguna.expert_gating_func"] = lagunaMoeGatingFunc(p.MoERouterUseSigmoid)
kv["laguna.decoder_sparse_step"] = cmp.Or(p.DecoderSparseStep, uint32(1))
// leading_dense_block_count is written only when the dense layers are
// exactly 0..n-1 (an empty list yields a count of 0).
if leading, ok := lagunaLeadingDensePrefix(p.MLPOnlyLayers); ok {
kv["laguna.leading_dense_block_count"] = leading
}
if len(p.MLPOnlyLayers) > 0 {
kv["laguna.dense_layers"] = p.MLPOnlyLayers
}
// Full-attention rope settings; freq_base and factor fall back to 10000
// and 1 when the config leaves them at zero.
ropeType := p.RopeParameters.ropeType()
kv["laguna.rope.freq_base"] = cmp.Or(p.RopeParameters.RopeTheta, float32(10000))
kv["laguna.rope.scaling.type"] = ropeType
ropeFactor := cmp.Or(p.RopeParameters.Factor, float32(1))
kv["laguna.rope.scaling.factor"] = ropeFactor
kv["laguna.rope.scaling.original_context_length"] = p.RopeParameters.OriginalMaxPositionEmbeddings
kv["laguna.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["laguna.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
kv["laguna.rope.scaling.attn_factor"] = lagunaAttentionFactor(ropeType, ropeFactor, p.RopeParameters.AttentionFactor)
kv["laguna.rope.partial_rotary_factor"] = cmp.Or(p.PartialRotaryFactor, float32(1))
// Sliding-window rope settings; the scaling type falls back to "linear".
swaRopeType := p.SwaRopeParameters.ropeType()
kv["laguna.rope.swa.freq_base"] = cmp.Or(p.SwaRopeParameters.RopeTheta, float32(10000))
kv["laguna.rope.swa.scaling.type"] = cmp.Or(swaRopeType, "linear")
kv["laguna.rope.swa.scaling.factor"] = cmp.Or(p.SwaRopeParameters.Factor, float32(1))
kv["laguna.rope.swa.partial_rotary_factor"] = cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1))
// When head_dim is unset, derive it from hidden_size / num_attention_heads.
headDim := p.HeadDim
if headDim == 0 && p.NumAttentionHeads > 0 {
headDim = p.HiddenSize / p.NumAttentionHeads
}
kv["laguna.rope.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.PartialRotaryFactor, float32(1)))
kv["laguna.rope.swa.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1)))
return kv
}
// parseMore ignores the supplied filesystem and returns the result of
// validating the already-decoded config.
func (p *lagunaModel) parseMore(_ iofs.FS) error {
return p.validate()
}
// validate checks the normalized config and returns an error describing the
// first required field that is missing or the first variant the converter
// does not support. Checks run in the order written, so when several
// problems exist only the earliest one is reported. It returns nil when every
// check passes.
func (p *lagunaModel) validate() error {
if p.NumHiddenLayers == 0 {
return fmt.Errorf("laguna: num_hidden_layers must be set")
}
if p.HiddenSize == 0 {
return fmt.Errorf("laguna: hidden_size must be set")
}
if p.HeadDim == 0 {
return fmt.Errorf("laguna: head_dim must be set")
}
if p.NumKeyValueHeads == 0 {
return fmt.Errorf("laguna: num_key_value_heads must be set")
}
if p.SwaAttentionSinkEnabled {
return fmt.Errorf("laguna: unsupported swa_attention_sink_enabled=true")
}
if !p.Gating.perHead() {
return fmt.Errorf("laguna: unsupported attention gating %q: only gating=\"per-head\" is supported", p.Gating)
}
if p.QKNormType != "rmsnorm" {
return fmt.Errorf("laguna: unsupported qk_norm_type %q: only rmsnorm is supported", p.QKNormType)
}
if !p.MoERouterUseSigmoid {
return fmt.Errorf("laguna: unsupported moe_router_use_sigmoid=false")
}
if p.MoEApplyRouterWeightOnInput {
return fmt.Errorf("laguna: unsupported moe_apply_router_weight_on_input=true")
}
// A zero step is accepted because it is treated as 1 elsewhere.
if p.DecoderSparseStep != 0 && p.DecoderSparseStep != 1 {
return fmt.Errorf("laguna: unsupported decoder_sparse_step=%d: only 1 is supported", p.DecoderSparseStep)
}
// Exactly one dense layer, at index 0, is accepted.
if len(p.MLPOnlyLayers) != 1 || p.MLPOnlyLayers[0] != 0 {
return fmt.Errorf("laguna: unsupported mlp_only_layers=%v: only [0] is supported", p.MLPOnlyLayers)
}
if p.NumExperts == 0 {
return fmt.Errorf("laguna: num_experts must be set")
}
if p.NumExpertsPerTok == 0 {
return fmt.Errorf("laguna: num_experts_per_tok must be set")
}
if p.MoEIntermediateSize == 0 {
return fmt.Errorf("laguna: moe_intermediate_size must be set")
}
if p.SharedExpertIntermediateSize == 0 {
return fmt.Errorf("laguna: shared_expert_intermediate_size must be set")
}
// layer_types is optional, but when present needs one entry per layer.
if len(p.LayerTypes) > 0 && len(p.LayerTypes) != int(p.NumHiddenLayers) {
return fmt.Errorf("laguna: layer_types has %d entries, expected %d", len(p.LayerTypes), p.NumHiddenLayers)
}
for i, layerType := range p.LayerTypes {
if !lagunaLayerIsGlobal(layerType) && !lagunaLayerIsSliding(layerType) {
return fmt.Errorf("laguna: unsupported layer_types[%d]=%q", i, layerType)
}
}
// Per-layer head counts are optional, but when present need one non-zero
// entry per layer; without them the model-wide head count is required.
if len(p.NumAttentionHeadsPerLayer) > 0 && len(p.NumAttentionHeadsPerLayer) != int(p.NumHiddenLayers) {
return fmt.Errorf("laguna: num_attention_heads_per_layer has %d entries, expected %d", len(p.NumAttentionHeadsPerLayer), p.NumHiddenLayers)
}
if len(p.NumAttentionHeadsPerLayer) == 0 && p.NumAttentionHeads == 0 {
return fmt.Errorf("laguna: num_attention_heads or num_attention_heads_per_layer must be set")
}
for i, heads := range p.NumAttentionHeadsPerLayer {
if heads == 0 {
return fmt.Errorf("laguna: num_attention_heads_per_layer[%d] must be non-zero", i)
}
}
return nil
}
// numHeadsForLayer returns the attention head count for the given layer: the
// per-layer entry when one exists and is non-zero, otherwise the model-wide
// NumAttentionHeads.
func (p *lagunaModel) numHeadsForLayer(layer uint32) uint32 {
	if int(layer) < len(p.NumAttentionHeadsPerLayer) {
		if heads := p.NumAttentionHeadsPerLayer[layer]; heads > 0 {
			return heads
		}
	}
	return p.NumAttentionHeads
}
// layerUsesMoE reports whether the given layer is a mixture-of-experts layer.
// Layers listed in MLPOnlyLayers never are; otherwise the model must have
// experts and (layer+1) must be a multiple of the sparse step, where a zero
// step is treated as 1.
func (p *lagunaModel) layerUsesMoE(layer uint32) bool {
	for i := range p.MLPOnlyLayers {
		if p.MLPOnlyLayers[i] == layer {
			return false
		}
	}
	if p.NumExperts == 0 {
		return false
	}
	sparseStep := cmp.Or(p.DecoderSparseStep, uint32(1))
	return (layer+1)%sparseStep == 0
}
// Replacements returns tensor-name substitutions as a flat list of
// alternating (source substring, replacement) pairs.
func (p *lagunaModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.g_proj", "attn_g",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.experts.e_score_correction_bias", "exp_probs_b.bias",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.experts.*.gate_proj", "ffn_gate_exps",
"mlp.experts.*.up_proj", "ffn_up_exps",
"mlp.experts.*.down_proj", "ffn_down_exps",
}
}
// Tensors converts the source tensors into output tensors. For every layer it
// registers three merge patterns (gate, up and down expert projections) and
// hands them to mergeTensors; every tensor mergeTensors leaves over is passed
// through with its name, kind and shape unchanged.
func (p *lagunaModel) Tensors(ts []Tensor) []*ggml.Tensor {
// Current Laguna drops store routed MoE experts as separate per-expert
// tensors. GGUF stores each projection as one stacked tensor. If future
// drops change expert naming or layout, update these patterns with a
// focused conversion test using the new tensor names.
merges := make([]merge, 0, p.NumHiddenLayers*3)
for i := range p.NumHiddenLayers {
merges = append(merges,
merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
},
merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
},
merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
},
)
}
out, rest := mergeTensors(ts, merges...)
// Everything that did not match a merge pattern is emitted as-is.
for _, t := range rest {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
// specialTokenTypes returns the special token type names for Laguna:
// "bos", "eos", "pad" and "unk".
func (p *lagunaModel) specialTokenTypes() []string {
return []string{"bos", "eos", "pad", "unk"}
}
// lagunaLayerIsSliding reports whether layerType names a sliding-window
// attention layer, ignoring case.
func lagunaLayerIsSliding(layerType string) bool {
	const sliding = "sliding_attention"
	return strings.EqualFold(sliding, layerType)
}
// lagunaLayerIsGlobal reports whether layerType names a full (global)
// attention layer. Both "full_attention" and "global_attention" are accepted,
// ignoring case.
func lagunaLayerIsGlobal(layerType string) bool {
	for _, name := range [...]string{"full_attention", "global_attention"} {
		if strings.EqualFold(layerType, name) {
			return true
		}
	}
	return false
}
// lagunaLeadingDensePrefix checks whether layers is exactly 0, 1, ..., n-1.
// If so it returns n and true (an empty list gives 0 and true); otherwise it
// returns 0 and false.
func lagunaLeadingDensePrefix(layers []uint32) (uint32, bool) {
	for want := 0; want < len(layers); want++ {
		if layers[want] != uint32(want) {
			return 0, false
		}
	}
	return uint32(len(layers)), true
}
// lagunaDenseLayers resolves the list of dense layer indices. A non-empty
// mlpOnlyLayers is returned unchanged. Otherwise the indices are the positions
// in mlpLayerTypes whose value is "dense" (case-insensitive); "sparse" entries
// are skipped and any other value produces an error. With both inputs empty
// it returns nil.
func lagunaDenseLayers(mlpOnlyLayers []uint32, mlpLayerTypes []string) ([]uint32, error) {
	switch {
	case len(mlpOnlyLayers) > 0:
		return mlpOnlyLayers, nil
	case len(mlpLayerTypes) == 0:
		return nil, nil
	}

	dense := make([]uint32, 0, len(mlpLayerTypes))
	for idx, kind := range mlpLayerTypes {
		if strings.EqualFold(kind, "dense") {
			dense = append(dense, uint32(idx))
			continue
		}
		if !strings.EqualFold(kind, "sparse") {
			return nil, fmt.Errorf("laguna: unsupported mlp_layer_types[%d]=%q", idx, kind)
		}
	}
	return dense, nil
}
func lagunaMoeGatingFunc(useSigmoid bool) uint32 {
if useSigmoid {
return lagunaGatingFuncSigmoid
}
return lagunaGatingFuncSoftmax
}
// lagunaAttentionFactor picks the rope attention factor. A non-zero
// attentionFactor is used as given. Otherwise, for rope type "yarn"
// (case-insensitive) with a scale factor above 1, the result is
// 0.1*ln(scaleFactor) + 1. In every other case it is 1.
func lagunaAttentionFactor(ropeType string, scaleFactor, attentionFactor float32) float32 {
	switch {
	case attentionFactor != 0:
		return attentionFactor
	case scaleFactor > 1 && strings.EqualFold(ropeType, "yarn"):
		return float32(0.1*math.Log(float64(scaleFactor)) + 1)
	default:
		return 1
	}
}
// lagunaRopeDim computes the rotary dimension count for a head of size
// headDim. The product headDim*partialRotaryFactor is truncated to an
// integer, replaced by headDim when it is zero or exceeds headDim, and then
// rounded down to an even number. If that rounding reaches zero, headDim is
// returned instead. A zero headDim always yields zero.
func lagunaRopeDim(headDim uint32, partialRotaryFactor float32) uint32 {
	if headDim == 0 {
		return 0
	}

	rotary := uint32(float32(headDim) * partialRotaryFactor)
	if rotary == 0 || rotary > headDim {
		rotary = headDim
	}

	// Drop the lowest bit's worth so the count is even.
	rotary -= rotary % 2
	if rotary == 0 {
		return headDim
	}
	return rotary
}
var (
_ ModelConverter = (*lagunaModel)(nil)
_ moreParser = (*lagunaModel)(nil)
)

View File

@@ -0,0 +1,450 @@
package convert
import (
"encoding/json"
"fmt"
"io"
"math"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
)
// lagunaTestTensor is a test tensor built on tensorBase that carries only a
// name and shape and writes no data.
type lagunaTestTensor struct {
tensorBase
}
// newLagunaTestTensor returns a lagunaTestTensor with the given name and shape.
func newLagunaTestTensor(name string, shape ...uint64) Tensor {
return &lagunaTestTensor{tensorBase: tensorBase{name: name, shape: shape}}
}
// WriteTo writes nothing and reports zero bytes written.
func (t *lagunaTestTensor) WriteTo(io.Writer) (int64, error) {
return 0, nil
}
// Clone returns a new tensor with the same name and a copy of the shape slice.
func (t *lagunaTestTensor) Clone() Tensor {
return &lagunaTestTensor{tensorBase: tensorBase{
name: t.name,
shape: append([]uint64(nil), t.shape...),
}}
}
// TestLagunaReplacements feeds source tensor names through a strings.Replacer
// built from lagunaModel.Replacements and checks each resulting name.
func TestLagunaReplacements(t *testing.T) {
p := lagunaModel{}
r := strings.NewReplacer(p.Replacements()...)
tests := []struct {
name string
in string
want string
}{
{"embed", "model.embed_tokens.weight", "token_embd.weight"},
{"final_norm", "model.norm.weight", "output_norm.weight"},
{"lm_head", "lm_head.weight", "output.weight"},
{"block prefix", "model.layers.7.input_layernorm.weight", "blk.7.attn_norm.weight"},
{"q", "model.layers.3.self_attn.q_proj.weight", "blk.3.attn_q.weight"},
{"k", "model.layers.3.self_attn.k_proj.weight", "blk.3.attn_k.weight"},
{"v", "model.layers.3.self_attn.v_proj.weight", "blk.3.attn_v.weight"},
{"o", "model.layers.3.self_attn.o_proj.weight", "blk.3.attn_output.weight"},
{"g", "model.layers.3.self_attn.g_proj.weight", "blk.3.attn_g.weight"},
{"q_norm", "model.layers.3.self_attn.q_norm.weight", "blk.3.attn_q_norm.weight"},
{"k_norm", "model.layers.3.self_attn.k_norm.weight", "blk.3.attn_k_norm.weight"},
{"post_attn_norm", "model.layers.3.post_attention_layernorm.weight", "blk.3.ffn_norm.weight"},
{"dense gate", "model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"},
{"dense up", "model.layers.0.mlp.up_proj.weight", "blk.0.ffn_up.weight"},
{"dense down", "model.layers.0.mlp.down_proj.weight", "blk.0.ffn_down.weight"},
{"shexp gate", "model.layers.5.mlp.shared_expert.gate_proj.weight", "blk.5.ffn_gate_shexp.weight"},
{"shexp up", "model.layers.5.mlp.shared_expert.up_proj.weight", "blk.5.ffn_up_shexp.weight"},
{"shexp down", "model.layers.5.mlp.shared_expert.down_proj.weight", "blk.5.ffn_down_shexp.weight"},
{"router", "model.layers.5.mlp.gate.weight", "blk.5.ffn_gate_inp.weight"},
{"score bias", "model.layers.5.mlp.experts.e_score_correction_bias", "blk.5.exp_probs_b.bias"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := r.Replace(tc.in); got != tc.want {
t.Errorf("Replace(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
func TestLagunaValidateRejectsUnsupportedVariants(t *testing.T) {
base := validLagunaTestModel()
tests := []struct {
name string
edit func(*lagunaModel)
want string
}{
{
name: "per-element gating",
edit: func(m *lagunaModel) {
m.Gating = "per-element"
},
want: "unsupported attention gating",
},
{
name: "attention sinks",
edit: func(m *lagunaModel) {
m.SwaAttentionSinkEnabled = true
},
want: "swa_attention_sink_enabled=true",
},
{
name: "qk norm disabled",
edit: func(m *lagunaModel) {
m.QKNormType = "none"
},
want: "unsupported qk_norm_type",
},
{
name: "softmax moe",
edit: func(m *lagunaModel) {
m.MoERouterUseSigmoid = false
},
want: "moe_router_use_sigmoid=false",
},
{
name: "router weight on input",
edit: func(m *lagunaModel) {
m.MoEApplyRouterWeightOnInput = true
},
want: "moe_apply_router_weight_on_input=true",
},
{
name: "unknown layer type",
edit: func(m *lagunaModel) {
m.LayerTypes[1] = "local_attention"
},
want: "unsupported layer_types[1]",
},
{
name: "nonstandard dense layout",
edit: func(m *lagunaModel) {
m.MLPOnlyLayers = []uint32{0, 3}
},
want: "unsupported mlp_only_layers",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m := base
m.LayerTypes = append([]string(nil), base.LayerTypes...)
m.NumAttentionHeadsPerLayer = append([]uint32(nil), base.NumAttentionHeadsPerLayer...)
m.MLPOnlyLayers = append([]uint32(nil), base.MLPOnlyLayers...)
tc.edit(&m)
err := m.validate()
if err == nil || !strings.Contains(err.Error(), tc.want) {
t.Fatalf("validate() error = %v, want substring %q", err, tc.want)
}
})
}
}
// TestLagunaGAConfigNormalizesBoolGatingAndNestedRope decodes a config that
// uses a boolean "gating" value and per-attention-type "rope_parameters", then
// checks that it validates and that the decoded fields carry the expected
// normalized and default values.
func TestLagunaGAConfigNormalizesBoolGatingAndNestedRope(t *testing.T) {
	const config = `{
"architectures": ["LagunaForCausalLM"],
"num_hidden_layers": 1,
"hidden_size": 8,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"gating": true,
"num_experts": 2,
"num_experts_per_tok": 1,
"moe_intermediate_size": 4,
"shared_expert_intermediate_size": 4,
"decoder_sparse_step": 1,
"mlp_layer_types": ["dense"],
"rope_parameters": {
"full_attention": {
"rope_theta": 500000,
"rope_type": "yarn",
"factor": 32,
"original_max_position_embeddings": 4096,
"beta_fast": 64,
"beta_slow": 1,
"attention_factor": 1,
"partial_rotary_factor": 0.5
},
"sliding_attention": {
"rope_theta": 10000,
"rope_type": "default",
"partial_rotary_factor": 1
}
}
}`

	var m lagunaModel
	if err := json.Unmarshal([]byte(config), &m); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	if err := m.validate(); err != nil {
		t.Fatalf("validate() error = %v", err)
	}

	// The boolean is expected to be kept as the raw "true" marker while
	// still being reported as per-head gating.
	if gating := m.Gating; gating != "true" {
		t.Fatalf("Gating = %q, want raw true marker", gating)
	}
	if perHead := m.Gating.perHead(); !perHead {
		t.Fatal("expected bool gating to normalize as per-head support")
	}

	// Fields the config leaves out.
	if norm := m.QKNormType; norm != "rmsnorm" {
		t.Fatalf("QKNormType = %q, want rmsnorm default", norm)
	}
	if !m.MoERouterUseSigmoid {
		t.Fatal("MoERouterUseSigmoid should default true")
	}
	if !m.NormTopKProb {
		t.Fatal("NormTopKProb should default true")
	}
	if diff := cmp.Diff(m.MLPOnlyLayers, []uint32{0}); diff != "" {
		t.Fatalf("MLPOnlyLayers mismatch (-got +want):\n%s", diff)
	}

	// Nested rope sections.
	if rope := m.RopeParameters; !(rope.RopeTheta == 500000 && rope.PartialRotaryFactor == 0.5) {
		t.Fatalf("full rope = %#v, want theta=500000 partial=0.5", rope)
	}
	if rope := m.SwaRopeParameters; !(rope.RopeTheta == 10000 && rope.PartialRotaryFactor == 1) {
		t.Fatalf("swa rope = %#v, want theta=10000 partial=1", rope)
	}
}
// validLagunaTestModel returns a small two-layer lagunaModel that the tests
// in this file use as their baseline configuration. Each call builds fresh
// slices, so callers may mutate the result freely.
func validLagunaTestModel() lagunaModel {
	var m lagunaModel
	m.ModelParameters.VocabSize = 32

	// Core dimensions.
	m.NumHiddenLayers = 2
	m.HiddenSize = 8
	m.IntermediateSize = 16
	m.NumAttentionHeads = 2
	m.NumKeyValueHeads = 1
	m.HeadDim = 4
	m.RMSNormEPS = 1e-6
	m.MaxPositionEmbeddings = 4096

	// Attention layout.
	m.SlidingWindow = 512
	m.Gating = "per-head"
	m.QKNormType = "rmsnorm"
	m.LayerTypes = []string{"global_attention", "sliding_attention"}
	m.NumAttentionHeadsPerLayer = []uint32{2, 2}

	// Mixture-of-experts settings; layer 0 is listed as MLP-only.
	m.NumExperts = 2
	m.NumExpertsPerTok = 1
	m.MoEIntermediateSize = 4
	m.SharedExpertIntermediateSize = 4
	m.NormTopKProb = true
	m.MoeRoutedScalingFactor = 2.5
	m.MoERouterUseSigmoid = true
	m.DecoderSparseStep = 1
	m.MLPOnlyLayers = []uint32{0}
	return m
}
// validLagunaTestTensors builds the list of test tensors for model m: the
// token embedding and output norm, then for every layer the attention
// tensors followed by either the MoE tensors (router, score bias, shared
// expert and one gate/up/down triple per routed expert) or the dense FFN
// tensors, depending on m.layerUsesMoE.
func validLagunaTestTensors(m lagunaModel) []Tensor {
	hidden := uint64(m.HiddenSize)
	headDim := uint64(m.HeadDim)
	kvWidth := uint64(m.NumKeyValueHeads * m.HeadDim)
	expertCount := uint64(m.NumExperts)
	sharedFF := uint64(m.SharedExpertIntermediateSize)
	expertFF := uint64(m.MoEIntermediateSize)
	denseFF := uint64(m.IntermediateSize)

	var ts []Tensor
	add := func(name string, shape ...uint64) {
		ts = append(ts, newLagunaTestTensor(name, shape...))
	}

	add("token_embd.weight", uint64(m.VocabSize), hidden)
	add("output_norm.weight", hidden)

	for layer := range m.NumHiddenLayers {
		prefix := fmt.Sprintf("blk.%d", layer)
		heads := uint64(m.numHeadsForLayer(layer))
		attnWidth := heads * headDim

		add(prefix+".attn_norm.weight", hidden)
		add(prefix+".ffn_norm.weight", hidden)
		add(prefix+".attn_q.weight", attnWidth, hidden)
		add(prefix+".attn_k.weight", kvWidth, hidden)
		add(prefix+".attn_v.weight", kvWidth, hidden)
		add(prefix+".attn_output.weight", hidden, attnWidth)
		add(prefix+".attn_g.weight", heads, hidden)
		add(prefix+".attn_q_norm.weight", headDim)
		add(prefix+".attn_k_norm.weight", headDim)

		if !m.layerUsesMoE(layer) {
			add(prefix+".ffn_gate.weight", denseFF, hidden)
			add(prefix+".ffn_up.weight", denseFF, hidden)
			add(prefix+".ffn_down.weight", hidden, denseFF)
			continue
		}

		add(prefix+".ffn_gate_inp.weight", expertCount, hidden)
		add(prefix+".exp_probs_b.bias", expertCount)
		add(prefix+".ffn_gate_shexp.weight", sharedFF, hidden)
		add(prefix+".ffn_up_shexp.weight", sharedFF, hidden)
		add(prefix+".ffn_down_shexp.weight", hidden, sharedFF)

		// Routed experts keep their per-expert names.
		for expert := range m.NumExperts {
			add(fmt.Sprintf("%s.mlp.experts.%d.gate_proj.weight", prefix, expert), expertFF, hidden)
			add(fmt.Sprintf("%s.mlp.experts.%d.up_proj.weight", prefix, expert), expertFF, hidden)
			add(fmt.Sprintf("%s.mlp.experts.%d.down_proj.weight", prefix, expert), hidden, expertFF)
		}
	}
	return ts
}
// TestLagunaTensorsMergeRoutedExperts runs Tensors over the baseline tensor
// list and checks that layer 1 exposes stacked *_exps tensors with an expert
// leading dimension, and that the per-expert gate_proj tensors are gone.
func TestLagunaTensorsMergeRoutedExperts(t *testing.T) {
	m := validLagunaTestModel()
	converted := m.Tensors(validLagunaTestTensors(m))

	byName := make(map[string]*ggml.Tensor, len(converted))
	for _, tensor := range converted {
		byName[tensor.Name] = tensor
	}

	experts := uint64(m.NumExperts)
	expertFF := uint64(m.MoEIntermediateSize)
	hidden := uint64(m.HiddenSize)
	wantShapes := map[string][]uint64{
		"blk.1.ffn_gate_exps.weight": {experts, expertFF, hidden},
		"blk.1.ffn_up_exps.weight":   {experts, expertFF, hidden},
		"blk.1.ffn_down_exps.weight": {experts, hidden, expertFF},
	}
	for name, want := range wantShapes {
		merged, found := byName[name]
		if !found {
			t.Fatalf("missing merged tensor %q", name)
		}
		if diff := cmp.Diff(want, merged.Shape); diff != "" {
			t.Fatalf("%s shape mismatch (-want +got):\n%s", name, diff)
		}
	}

	for expert := range m.NumExperts {
		leftover := fmt.Sprintf("blk.1.mlp.experts.%d.gate_proj.weight", expert)
		if _, found := byName[leftover]; found {
			t.Fatalf("unexpected unmerged expert tensor %q", leftover)
		}
	}
}
// TestLagunaKVShape builds a four-layer model (one full-attention layer
// followed by three sliding-attention layers), calls KV, and checks that all
// required metadata keys are present and that selected values have the
// expected type and content.
func TestLagunaKVShape(t *testing.T) {
	var m lagunaModel
	m.NumHiddenLayers = 4
	m.HiddenSize = 128
	m.IntermediateSize = 256
	m.NumAttentionHeads = 8
	m.NumKeyValueHeads = 4
	m.HeadDim = 16
	m.RMSNormEPS = 1e-6
	m.MaxPositionEmbeddings = 4096
	m.SlidingWindow = 512
	m.PartialRotaryFactor = 0.5
	m.Gating = "per-head"
	m.QKNormType = "rmsnorm"
	m.LayerTypes = []string{"full_attention", "sliding_attention", "sliding_attention", "sliding_attention"}
	m.NumAttentionHeadsPerLayer = []uint32{8, 16, 16, 16}
	m.NumExperts = 32
	m.NumExpertsPerTok = 4
	m.MoEIntermediateSize = 64
	m.SharedExpertIntermediateSize = 64
	m.NormTopKProb = true
	m.MoeRoutedScalingFactor = 2.5
	m.MoERouterUseSigmoid = true
	m.DecoderSparseStep = 1
	m.MLPOnlyLayers = []uint32{0}

	full := &m.RopeParameters
	full.RopeTheta = 500000
	full.RopeType = "yarn"
	full.Factor = 32
	full.OriginalMaxPositionEmbeddings = 4096
	full.BetaFast = 64
	full.BetaSlow = 1

	swa := &m.SwaRopeParameters
	swa.RopeTheta = 10000
	swa.RopeType = "linear"
	swa.Factor = 1
	swa.PartialRotaryFactor = 1

	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}, Template: "{% include 'chat_template.jinja' %}"})

	mustHave := []string{
		"general.architecture",
		"tokenizer.ggml.pre",
		"laguna.block_count",
		"laguna.context_length",
		"laguna.embedding_length",
		"laguna.feed_forward_length",
		"laguna.attention.head_count",
		"laguna.attention.head_count_kv",
		"laguna.attention.key_length",
		"laguna.attention.value_length",
		"laguna.attention.layer_norm_rms_epsilon",
		"laguna.attention.sliding_window",
		"laguna.attention.layer_types",
		"laguna.attention.sliding_window_pattern",
		"laguna.attention.gating_type",
		"laguna.attention.qk_norm",
		"laguna.expert_count",
		"laguna.expert_used_count",
		"laguna.expert_feed_forward_length",
		"laguna.expert_shared_feed_forward_length",
		"laguna.expert_shared_count",
		"laguna.expert_weights_norm",
		"laguna.expert_weights_scale",
		"laguna.expert_gating_func",
		"laguna.leading_dense_block_count",
		"laguna.dense_layers",
		"laguna.rope.freq_base",
		"laguna.rope.scaling.type",
		"laguna.rope.scaling.factor",
		"laguna.rope.partial_rotary_factor",
		"laguna.rope.swa.freq_base",
		"laguna.rope.swa.scaling.type",
		"laguna.rope.dimension_count",
		"laguna.rope.swa.dimension_count",
	}
	for _, key := range mustHave {
		if _, present := kv[key]; !present {
			t.Errorf("missing required KV: %s", key)
		}
	}

	if arch := kv["general.architecture"]; arch != "laguna" {
		t.Errorf("architecture = %v, want laguna", arch)
	}
	if addBOS := kv["tokenizer.ggml.add_bos_token"]; addBOS != false {
		t.Errorf("tokenizer.ggml.add_bos_token = %v, want false", addBOS)
	}
	if _, present := kv["tokenizer.chat_template"]; present {
		t.Fatal("tokenizer.chat_template should be omitted for Laguna")
	}
	if gate := kv["laguna.expert_gating_func"]; gate != lagunaGatingFuncSigmoid {
		t.Errorf("expert_gating_func = %v, want sigmoid(%d)", gate, lagunaGatingFuncSigmoid)
	}
	if dense := kv["laguna.leading_dense_block_count"]; dense != uint32(1) {
		t.Errorf("leading_dense_block_count = %v, want 1", dense)
	}
	// HeadDim is 16: the full-attention partial factor of 0.5 gives 8 rope
	// dimensions, the sliding-attention factor of 1 gives 16.
	if dims := kv["laguna.rope.dimension_count"]; dims != uint32(8) {
		t.Errorf("rope.dimension_count = %v, want 8", dims)
	}
	if dims := kv["laguna.rope.swa.dimension_count"]; dims != uint32(16) {
		t.Errorf("rope.swa.dimension_count = %v, want 16", dims)
	}

	wantTypes := []uint32{0, 1, 1, 1}
	gotTypes, typesOK := kv["laguna.attention.layer_types"].([]uint32)
	typesMatch := typesOK && len(gotTypes) == len(wantTypes)
	for i := 0; typesMatch && i < len(wantTypes); i++ {
		typesMatch = gotTypes[i] == wantTypes[i]
	}
	if !typesMatch {
		t.Fatalf("layer_types = %#v, want [0 1 1 1]", kv["laguna.attention.layer_types"])
	}

	wantPattern := []bool{false, true, true, true}
	gotPattern, patternOK := kv["laguna.attention.sliding_window_pattern"].([]bool)
	patternMatch := patternOK && len(gotPattern) == len(wantPattern)
	for i := 0; patternMatch && i < len(wantPattern); i++ {
		patternMatch = gotPattern[i] == wantPattern[i]
	}
	if !patternMatch {
		t.Fatalf("sliding_window_pattern = %#v, want [false true true true]", kv["laguna.attention.sliding_window_pattern"])
	}
}
// TestLagunaKVYarnAttentionFactorFallback enables yarn scaling with a factor
// of 32 and no explicit attention factor, and expects KV to emit a float32
// attn_factor equal to 0.1*ln(32)+1 within 1e-6.
func TestLagunaKVYarnAttentionFactorFallback(t *testing.T) {
	const scale = 32

	m := validLagunaTestModel()
	m.RopeParameters.RopeType = "yarn"
	m.RopeParameters.Factor = scale

	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})

	raw := kv["laguna.rope.scaling.attn_factor"]
	got, isFloat := raw.(float32)
	if !isFloat {
		t.Fatalf("attn_factor type = %T, want float32", raw)
	}

	want := float32(0.1*math.Log(scale) + 1)
	if math.Abs(float64(got-want)) > 1e-6 {
		t.Fatalf("attn_factor = %v, want %v", got, want)
	}
}

View File

@@ -3,6 +3,7 @@ package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io/fs"
"math"
@@ -69,7 +70,415 @@ type nemotronHModel struct {
ExpertGroupUsedCount uint32 `json:"topk_group"`
}
// nemotronHNanoVLModel holds the decoded configuration for the Nemotron-H
// omni (vision + audio + text) converter. Top-level fields come from the
// model config via their json tags; Preprocessor is filled separately by
// parseMore from preprocessor_config.json when that file exists. Several
// settings (patch size, downsample ratio, thumbnail flag, normalization)
// appear both at the top level and under Preprocessor; the vision* accessor
// methods choose between them.
type nemotronHNanoVLModel struct {
ModelParameters
// MaxSequenceLength, when non-zero, overrides the language model's
// MaxPositionEmbeddings in parseMore.
MaxSequenceLength uint32 `json:"max_sequence_length"`
ForceImageSize uint32 `json:"force_image_size"`
DownsampleRatio float32 `json:"downsample_ratio"`
PatchSize uint32 `json:"patch_size"`
// UseThumbnail is a pointer so an absent value can be told apart from an
// explicit false.
UseThumbnail *bool `json:"use_thumbnail"`
// Image placeholder token: an explicit id, plus token strings that KV
// looks up in the vocabulary.
ImgContextTokenID uint32 `json:"img_context_token_id"`
ImgContextToken string `json:"img_context_token"`
ImgStartToken string `json:"img_start_token"`
ImgEndToken string `json:"img_end_token"`
VitHiddenSize uint32 `json:"vit_hidden_size"`
ProjectorHidden uint32 `json:"projector_hidden_size"`
// Audio placeholder token, handled the same way as the image token.
SoundContextTokenID uint32 `json:"sound_context_token_id"`
SoundContextToken string `json:"sound_context_token"`
NormMean []float32 `json:"norm_mean"`
NormStd []float32 `json:"norm_std"`
VisionConfig radioConfig `json:"vision_config"`
SoundConfig soundConfig `json:"sound_config"`
// LLMConfig is the nested language-model configuration; KV, Tensors,
// Replacements and specialTokenTypes all delegate to it.
LLMConfig nemotronHModel `json:"llm_config"`
// Preprocessor mirrors preprocessor_config.json. It carries no json tag
// of its own because parseMore unmarshals that file directly into it.
Preprocessor struct {
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
DownsampleRatio float32 `json:"downsample_ratio"`
MaxNumTiles uint32 `json:"max_num_tiles"`
UseThumbnail *bool `json:"use_thumbnail"`
NormMean []float32 `json:"norm_mean"`
NormStd []float32 `json:"norm_std"`
}
}
// soundConfig is the "sound_config" section of the omni model config. The
// converter treats the audio encoder as present only when NumHiddenLayers is
// non-zero; most other zero-valued fields are replaced by defaults in the
// sound* accessor methods.
type soundConfig struct {
// ModelType, when set, must be "parakeet" for parseMore to accept it.
ModelType string `json:"model_type"`
HiddenSize uint32 `json:"hidden_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
ConvKernelSize uint32 `json:"conv_kernel_size"`
SubsamplingConvChannels uint32 `json:"subsampling_conv_channels"`
SubsamplingConvKernelSize uint32 `json:"subsampling_conv_kernel_size"`
SubsamplingConvStride uint32 `json:"subsampling_conv_stride"`
SubsamplingFactor uint32 `json:"subsampling_factor"`
NumMelBins uint32 `json:"num_mel_bins"`
ProjectionHiddenSize uint32 `json:"projection_hidden_size"`
SamplingRate uint32 `json:"sampling_rate"`
// ScaleInput is written to the metadata as-is; it has no default.
ScaleInput bool `json:"scale_input"`
}
// radioConfig is the "vision_config" section of the omni model config.
type radioConfig struct {
// Version, when non-blank, must be "radio_v2.5-h" for parseMore to
// accept the model.
Version string `json:"version"`
PatchSize uint32 `json:"patch_size"`
MaxResolution uint32 `json:"max_resolution"`
MinNumPatches uint32 `json:"min_num_patches"`
MaxNumPatches uint32 `json:"max_num_patches"`
SeparateVideoEmbedder bool `json:"separate_video_embedder"`
// Args holds the same patch limits under a nested "args" object; the
// accessors use these only when the top-level values are zero.
Args struct {
MinNumPatches uint32 `json:"min_num_patches"`
MaxNumPatches uint32 `json:"max_num_patches"`
} `json:"args"`
}
// Compile-time checks that both converters implement ModelConverter.
var _ ModelConverter = (*nemotronHModel)(nil)
var _ ModelConverter = (*nemotronHNanoVLModel)(nil)
// parseMore finishes loading the omni configuration after the main config
// has been decoded and rejects variants the converter does not support.
//
// It copies a non-zero MaxSequenceLength into the language model's
// MaxPositionEmbeddings, runs the language model's own parseMore, and reads
// preprocessor_config.json from fsys into n.Preprocessor. A missing
// preprocessor file is not an error; any other read failure and any JSON
// parse failure is returned wrapped with the file name.
//
// It then requires the RADIO version (when set) to be "radio_v2.5-h", the
// vision patch size to be 16 and the projector scale factor to be 2. When
// the sound config declares hidden layers it also requires the model type
// (when set) to be "parakeet", a subsampling factor of 8 and 128 mel bins.
func (n *nemotronHNanoVLModel) parseMore(fsys fs.FS) error {
	if n.MaxSequenceLength > 0 {
		n.LLMConfig.MaxPositionEmbeddings = n.MaxSequenceLength
	}

	if err := n.LLMConfig.parseMore(fsys); err != nil {
		return err
	}

	bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
	switch {
	case err == nil:
		if err := json.Unmarshal(bts, &n.Preprocessor); err != nil {
			return fmt.Errorf("nemotron_h_omni: parse preprocessor_config.json: %w", err)
		}
	case !errors.Is(err, fs.ErrNotExist):
		// Wrap with the file name so the failure is attributable; %w keeps
		// errors.Is / errors.As working for callers.
		return fmt.Errorf("nemotron_h_omni: read preprocessor_config.json: %w", err)
	}

	if version := strings.TrimSpace(n.VisionConfig.Version); version != "" && version != "radio_v2.5-h" {
		return fmt.Errorf("nemotron_h_omni: unsupported RADIO version %q", version)
	}

	if patchSize := n.visionPatchSize(); patchSize != 16 {
		return fmt.Errorf("nemotron_h_omni: unsupported vision patch_size=%d", patchSize)
	}

	if scale := n.visionProjectorScaleFactor(); scale != 2 {
		return fmt.Errorf("nemotron_h_omni: unsupported vision projector scale factor=%d", scale)
	}

	if n.SoundConfig.NumHiddenLayers > 0 {
		if modelType := strings.TrimSpace(n.SoundConfig.ModelType); modelType != "" && modelType != "parakeet" {
			return fmt.Errorf("nemotron_h_omni: unsupported sound model_type %q", modelType)
		}
		// NOTE(review): soundHiddenSize and soundAttentionHeads substitute
		// non-zero defaults for missing values, so the next two checks
		// cannot currently fail. Confirm whether a missing value should be
		// rejected instead of defaulted.
		if n.soundHiddenSize() == 0 {
			return fmt.Errorf("nemotron_h_omni: sound hidden_size must be set")
		}
		if n.soundAttentionHeads() == 0 {
			return fmt.Errorf("nemotron_h_omni: sound num_attention_heads must be set")
		}
		if n.soundSubsamplingFactor() != 8 {
			return fmt.Errorf("nemotron_h_omni: unsupported sound subsampling_factor=%d", n.soundSubsamplingFactor())
		}
		if n.soundMelBins() != 128 {
			return fmt.Errorf("nemotron_h_omni: unsupported sound num_mel_bins=%d", n.soundMelBins())
		}
	}

	return nil
}
// KV returns the metadata for the omni model. It starts from the language
// model's metadata, sets the architecture to "nemotron_h_omni", and adds the
// vision keys. Audio keys are added only when the sound config declares
// hidden layers.
//
// Placeholder token ids use the explicit id from the config when it is
// greater than zero; otherwise the token string is looked up in the
// tokenizer vocabulary and the key is left unset if it is not found.
func (n *nemotronHNanoVLModel) KV(t *Tokenizer) KV {
	kv := n.LLMConfig.KV(t)
	kv["general.architecture"] = "nemotron_h_omni"

	type entry struct {
		key   string
		value any
	}
	store := func(entries ...entry) {
		for _, e := range entries {
			kv[e.key] = e.value
		}
	}

	tokenID := func(key string, explicit uint32, token string) {
		switch {
		case explicit > 0:
			kv[key] = explicit
		case t == nil || t.Vocabulary == nil:
			// No vocabulary to search; leave the key unset.
		default:
			if i := slices.Index(t.Vocabulary.Tokens, token); i >= 0 {
				kv[key] = uint32(i)
			}
		}
	}

	store(
		entry{"vision.block_count", n.visionBlockCount()},
		entry{"vision.embedding_length", n.visionEmbeddingLength()},
		entry{"vision.feed_forward_length", n.visionFeedForwardLength()},
		entry{"vision.attention.head_count", n.visionAttentionHeads()},
		entry{"vision.attention.layer_norm_epsilon", float32(1e-6)},
		entry{"vision.patch_size", n.visionPatchSize()},
		entry{"vision.image_size", n.visionImageSize()},
		entry{"vision.max_tiles", n.visionMaxTiles()},
		entry{"vision.use_thumbnail", n.visionUseThumbnail()},
		entry{"vision.num_channels", uint32(3)},
		entry{"vision.image_mean", slices.Clone(defaultFloat32Slice(n.visionMean(), imageNetStandardMean))},
		entry{"vision.image_std", slices.Clone(defaultFloat32Slice(n.visionStd(), imageNetStandardSTD))},
		entry{"vision.projector.scale_factor", n.visionProjectorScaleFactor()},
	)

	// Patch limits are optional and only written when configured.
	if lower := n.visionMinNumPatches(); lower > 0 {
		kv["vision.min_num_patches"] = lower
	}
	if upper := n.visionMaxNumPatches(); upper > 0 {
		kv["vision.max_num_patches"] = upper
	}

	tokenID("vision.image_token_id", n.ImgContextTokenID, cmp.Or(n.ImgContextToken, "<image>"))
	tokenID("vision.image_start_token_id", 0, cmp.Or(n.ImgStartToken, "<img>"))
	tokenID("vision.image_end_token_id", 0, cmp.Or(n.ImgEndToken, "</img>"))

	if n.SoundConfig.NumHiddenLayers == 0 {
		return kv
	}

	store(
		entry{"audio.block_count", n.SoundConfig.NumHiddenLayers},
		entry{"audio.embedding_length", n.soundHiddenSize()},
		entry{"audio.feed_forward_length", n.soundFeedForwardLength()},
		entry{"audio.attention.head_count", n.soundAttentionHeads()},
		entry{"audio.attention.layer_norm_epsilon", float32(1e-5)},
		entry{"audio.conv_kernel_size", n.soundConvKernelSize()},
		entry{"audio.num_mel_bins", n.soundMelBins()},
		entry{"audio.sample_rate", n.soundSampleRate()},
		entry{"audio.subsampling_factor", n.soundSubsamplingFactor()},
		entry{"audio.subsampling_conv_channels", n.soundSubsamplingConvChannels()},
		entry{"audio.subsampling_conv_kernel_size", n.soundSubsamplingConvKernelSize()},
		entry{"audio.subsampling_conv_stride", n.soundSubsamplingConvStride()},
		entry{"audio.projection_hidden_size", n.soundProjectionHiddenSize()},
		entry{"audio.scale_input", n.SoundConfig.ScaleInput},
	)
	tokenID("audio.sound_token_id", n.SoundContextTokenID, cmp.Or(n.SoundContextToken, "<so_embedding>"))

	return kv
}
// Tensors converts the already-renamed input tensors. Vision ("v."), audio
// ("a.") and projector ("mm.") tensors are handled here; everything else is
// passed to the language model's Tensors, whose output comes first in the
// returned slice.
//
// Cases are checked in this order:
//   - names matched by isNemotronHNanoVLOmittedTensor are dropped;
//   - names containing ".attn_qkv" are split along dimension 0 into
//     attn_q, attn_k and attn_v;
//   - "v.position_embd" loses a leading dimension of size 1 when it is 3-D;
//   - other a./v./mm. tensors are kept, with 3-D audio block conv weights
//     reduced to 2-D (conv_dw drops the middle dimension; conv_pw1/conv_pw2
//     drop a trailing dimension of size 1).
func (n *nemotronHNanoVLModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var (
		text  []Tensor
		extra []*ggml.Tensor
	)

	keep := func(t Tensor, name string, shape []uint64) {
		extra = append(extra, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    shape,
			WriterTo: t,
		})
	}

	for _, t := range ts {
		name := t.Name()

		if isNemotronHNanoVLOmittedTensor(name) {
			continue
		}

		if strings.Contains(name, ".attn_qkv") {
			parts := splitDim(t, 0,
				split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
				split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
				split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
			)
			extra = append(extra, slices.Collect(parts)...)
			continue
		}

		if name == "v.position_embd" {
			shape := t.Shape()
			if len(shape) == 3 && shape[0] == 1 {
				shape = shape[1:]
			}
			keep(t, name, shape)
			continue
		}

		if !(strings.HasPrefix(name, "a.") || strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
			text = append(text, t)
			continue
		}

		shape := slices.Clone(t.Shape())
		// Only 3-D audio block weights are candidates for squeezing.
		audioBlockWeight3D := strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, ".weight") && len(shape) == 3
		switch {
		case audioBlockWeight3D && strings.Contains(name, ".conv_dw."):
			t.SetRepacker(squeezeMiddleDim)
			shape = []uint64{shape[0], shape[2]}
		case audioBlockWeight3D && shape[2] == 1 && (strings.Contains(name, ".conv_pw1.") || strings.Contains(name, ".conv_pw2.")):
			t.SetRepacker(squeezeLastDim)
			shape = shape[:2]
		}
		keep(t, name, shape)
	}

	return append(n.LLMConfig.Tensors(text), extra...)
}
// Replacements returns old/new substring pairs that rename source checkpoint
// tensor names to the converter's names. The vision, projector and audio
// pairs defined here come first and the language model's own replacements
// are appended after them. Pair order is kept as written because it can
// matter when several patterns match the same text.
func (n *nemotronHNanoVLModel) Replacements() []string {
return append([]string{
// Strip the language-model wrapper prefix.
"language_model.", "",
// Vision encoder.
"vision_model.radio_model.model.patch_generator.embedder", "v.patch_embd",
"vision_model.radio_model.model.patch_generator.pos_embed", "v.position_embd",
"vision_model.radio_model.model.patch_generator.cls_token.token", "v.cls_embd",
"vision_model.radio_model.model.blocks", "v.blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"norm1", "ln1",
"norm2", "ln2",
// Vision projector.
"mlp1.0", "mm.norm",
"mlp1.1", "mm.1",
"mlp1.3", "mm.2",
// Audio encoder: feature extractor and subsampling stack.
"sound_encoder.encoder.feature_extractor.featurizer.fb", "a.feature_extractor.fb",
"sound_encoder.encoder.feature_extractor.featurizer.window", "a.feature_extractor.window",
"sound_encoder.encoder.subsampling.layers.0", "a.subsampling.conv0",
"sound_encoder.encoder.subsampling.layers.2", "a.subsampling.dw1",
"sound_encoder.encoder.subsampling.layers.3", "a.subsampling.pw1",
"sound_encoder.encoder.subsampling.layers.5", "a.subsampling.dw2",
"sound_encoder.encoder.subsampling.layers.6", "a.subsampling.pw2",
"sound_encoder.encoder.subsampling.linear", "a.subsampling.linear",
// Audio encoder blocks.
"sound_encoder.encoder.layers", "a.blk",
"feed_forward1.linear1", "ffn1_up",
"feed_forward1.linear2", "ffn1_down",
"feed_forward2.linear1", "ffn2_up",
"feed_forward2.linear2", "ffn2_down",
"norm_feed_forward1", "ffn1_norm",
"norm_feed_forward2", "ffn2_norm",
"norm_self_att", "attn_norm",
"norm_conv", "conv_norm",
"norm_out", "out_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
"self_attn.relative_k_proj", "attn_rel_k",
"self_attn.bias_u", "attn_bias_u",
"self_attn.bias_v", "attn_bias_v",
"conv.pointwise_conv1", "conv_pw1",
"conv.pointwise_conv2", "conv_pw2",
"conv.depthwise_conv", "conv_dw",
"conv.norm", "conv_bn",
// Audio projector.
"sound_projection.norm", "mm.a.norm",
"sound_projection.linear1", "mm.a.1",
"sound_projection.linear2", "mm.a.2",
}, n.LLMConfig.Replacements()...)
}
// specialTokenTypes reports the special token types of the nested language
// model; the omni wrapper adds none of its own.
func (n *nemotronHNanoVLModel) specialTokenTypes() []string {
	llm := &n.LLMConfig
	return llm.specialTokenTypes()
}
// isNemotronHNanoVLOmittedTensor reports whether the named tensor is left out
// of the converted model: names ending in ".conv_bn.num_batches_tracked", and
// names under the RADIO input conditioner or video embedder prefixes.
func isNemotronHNanoVLOmittedTensor(name string) bool {
	switch {
	case strings.HasSuffix(name, ".conv_bn.num_batches_tracked"):
		return true
	case strings.HasPrefix(name, "vision_model.radio_model.input_conditioner."):
		return true
	case strings.HasPrefix(name, "vision_model.radio_model.model.patch_generator.video_embedder"):
		return true
	}
	return false
}
// squeezeLastDim is the repacker installed for tensors whose trailing
// dimension is dropped. It returns the values unchanged and never fails;
// the tensor name and shape arguments are ignored.
func squeezeLastDim(_ string, values []float32, _ []uint64) ([]float32, error) {
	var err error
	return values, err
}
// visionImageSize returns the first non-zero of the config's force image
// size and the preprocessor's image size, falling back to 512.
func (n *nemotronHNanoVLModel) visionImageSize() uint32 {
	const fallback uint32 = 512
	return cmp.Or(n.ForceImageSize, n.Preprocessor.ImageSize, fallback)
}
// visionPatchSize returns the first non-zero patch size among the top-level
// config, the preprocessor config and the vision config, falling back to 16.
func (n *nemotronHNanoVLModel) visionPatchSize() uint32 {
	const fallback uint32 = 16
	return cmp.Or(n.PatchSize, n.Preprocessor.PatchSize, n.VisionConfig.PatchSize, fallback)
}
// visionProjectorScaleFactor converts the downsample ratio into an integer
// scale factor: the rounded reciprocal of the ratio, never less than 1. The
// ratio is the first non-zero of the top-level and preprocessor values,
// defaulting to 0.5; a ratio that is zero or negative yields 2.
func (n *nemotronHNanoVLModel) visionProjectorScaleFactor() uint32 {
	const fallbackRatio float32 = 0.5

	ratio := cmp.Or(n.DownsampleRatio, n.Preprocessor.DownsampleRatio, fallbackRatio)
	if ratio <= 0 {
		return 2
	}

	inverse := math.Round(1.0 / float64(ratio))
	return max(1, uint32(inverse))
}
// visionBlockCount returns the fixed vision block count of 32; it is not
// read from the config.
func (n *nemotronHNanoVLModel) visionBlockCount() uint32 {
	const blocks uint32 = 32
	return blocks
}
// visionEmbeddingLength returns the configured ViT hidden size, or 1280 when
// the config leaves it at zero.
func (n *nemotronHNanoVLModel) visionEmbeddingLength() uint32 {
	if width := n.VitHiddenSize; width != 0 {
		return width
	}
	return 1280
}
// visionAttentionHeads returns the fixed vision attention head count of 16;
// it is not read from the config.
func (n *nemotronHNanoVLModel) visionAttentionHeads() uint32 {
	const heads uint32 = 16
	return heads
}
// visionFeedForwardLength returns four times the vision embedding length.
func (n *nemotronHNanoVLModel) visionFeedForwardLength() uint32 {
	width := n.visionEmbeddingLength()
	return width * 4
}
// visionMaxTiles returns the preprocessor's maximum tile count, or 12 when
// it is zero.
func (n *nemotronHNanoVLModel) visionMaxTiles() uint32 {
	if tiles := n.Preprocessor.MaxNumTiles; tiles != 0 {
		return tiles
	}
	return 12
}
// visionMinNumPatches returns the vision config's minimum patch count,
// using the nested args value when the top-level one is zero. The result is
// zero when neither is set.
func (n *nemotronHNanoVLModel) visionMinNumPatches() uint32 {
	if top := n.VisionConfig.MinNumPatches; top != 0 {
		return top
	}
	return n.VisionConfig.Args.MinNumPatches
}
// visionMaxNumPatches returns the vision config's maximum patch count,
// using the nested args value when the top-level one is zero. The result is
// zero when neither is set.
func (n *nemotronHNanoVLModel) visionMaxNumPatches() uint32 {
	if top := n.VisionConfig.MaxNumPatches; top != 0 {
		return top
	}
	return n.VisionConfig.Args.MaxNumPatches
}
// visionUseThumbnail returns the explicit use_thumbnail setting, preferring
// the top-level config over the preprocessor config, and defaults to true
// when neither specifies one.
func (n *nemotronHNanoVLModel) visionUseThumbnail() bool {
	switch {
	case n.UseThumbnail != nil:
		return *n.UseThumbnail
	case n.Preprocessor.UseThumbnail != nil:
		return *n.Preprocessor.UseThumbnail
	default:
		return true
	}
}
// visionMean returns the top-level normalization mean when it is non-empty,
// otherwise the preprocessor's (which may itself be empty).
func (n *nemotronHNanoVLModel) visionMean() []float32 {
	if len(n.NormMean) == 0 {
		return n.Preprocessor.NormMean
	}
	return n.NormMean
}
// visionStd returns the top-level normalization standard deviation when it
// is non-empty, otherwise the preprocessor's (which may itself be empty).
func (n *nemotronHNanoVLModel) visionStd() []float32 {
	if len(n.NormStd) == 0 {
		return n.Preprocessor.NormStd
	}
	return n.NormStd
}
// soundHiddenSize returns the sound config's hidden size, or 1024 when it is
// zero.
func (n *nemotronHNanoVLModel) soundHiddenSize() uint32 {
	const fallback uint32 = 1024
	return cmp.Or(n.SoundConfig.HiddenSize, fallback)
}
// soundAttentionHeads returns the sound config's attention head count, or 8
// when it is zero.
func (n *nemotronHNanoVLModel) soundAttentionHeads() uint32 {
	const fallback uint32 = 8
	return cmp.Or(n.SoundConfig.NumAttentionHeads, fallback)
}
// soundFeedForwardLength returns the sound config's intermediate size, or
// four times the sound hidden size when it is zero.
func (n *nemotronHNanoVLModel) soundFeedForwardLength() uint32 {
	derived := 4 * n.soundHiddenSize()
	return cmp.Or(n.SoundConfig.IntermediateSize, derived)
}
// soundConvKernelSize returns the sound config's convolution kernel size, or
// 9 when it is zero.
func (n *nemotronHNanoVLModel) soundConvKernelSize() uint32 {
	const fallback uint32 = 9
	return cmp.Or(n.SoundConfig.ConvKernelSize, fallback)
}
// soundMelBins returns the sound config's mel bin count, or 128 when it is
// zero.
func (n *nemotronHNanoVLModel) soundMelBins() uint32 {
	const fallback uint32 = 128
	return cmp.Or(n.SoundConfig.NumMelBins, fallback)
}
// soundSampleRate returns the sound config's sampling rate, or 16000 when it
// is zero.
func (n *nemotronHNanoVLModel) soundSampleRate() uint32 {
	const fallback uint32 = 16000
	return cmp.Or(n.SoundConfig.SamplingRate, fallback)
}
// soundSubsamplingFactor returns the sound config's subsampling factor, or 8
// when it is zero.
func (n *nemotronHNanoVLModel) soundSubsamplingFactor() uint32 {
	const fallback uint32 = 8
	return cmp.Or(n.SoundConfig.SubsamplingFactor, fallback)
}
// soundSubsamplingConvChannels returns the sound config's subsampling
// convolution channel count, or 256 when it is zero.
func (n *nemotronHNanoVLModel) soundSubsamplingConvChannels() uint32 {
	const fallback uint32 = 256
	return cmp.Or(n.SoundConfig.SubsamplingConvChannels, fallback)
}
// soundSubsamplingConvKernelSize returns the sound config's subsampling
// convolution kernel size, or 3 when it is zero.
func (n *nemotronHNanoVLModel) soundSubsamplingConvKernelSize() uint32 {
	const fallback uint32 = 3
	return cmp.Or(n.SoundConfig.SubsamplingConvKernelSize, fallback)
}
// soundSubsamplingConvStride returns the sound config's subsampling
// convolution stride, or 2 when it is zero.
func (n *nemotronHNanoVLModel) soundSubsamplingConvStride() uint32 {
	const fallback uint32 = 2
	return cmp.Or(n.SoundConfig.SubsamplingConvStride, fallback)
}
// soundProjectionHiddenSize returns the sound config's projection hidden
// size, or 4096 when it is zero.
func (n *nemotronHNanoVLModel) soundProjectionHiddenSize() uint32 {
	const fallback uint32 = 4096
	return cmp.Or(n.SoundConfig.ProjectionHiddenSize, fallback)
}
// Default per-channel image normalization constants, used by KV when the
// model config and preprocessor config provide no mean/std of their own.
// NOTE(review): these values look like the normalization statistics
// published with OpenAI CLIP rather than the classic ImageNet
// 0.485/0.456/0.406 set; confirm the naming is intentional.
var (
imageNetStandardMean = []float32{0.48145466, 0.4578275, 0.40821073}
imageNetStandardSTD = []float32{0.26862954, 0.26130258, 0.27577711}
)
func (n *nemotronHModel) parseMore(_ fs.FS) error {
if n.NumHiddenLayers == 0 {

View File

@@ -217,6 +217,316 @@ func TestNemotronHLoadModelMetadata(t *testing.T) {
}
}
// TestNemotronHNanoVLLoadModelMetadata writes a NemotronH_Nano_VL_V2 config,
// a preprocessor config and an empty tokenizer to a temp directory, loads
// them through LoadModelMetadata, and checks the converter type and a
// selection of the metadata it produces.
func TestNemotronHNanoVLLoadModelMetadata(t *testing.T) {
	const config = `{
"architectures": ["NemotronH_Nano_VL_V2"],
"model_type": "NemotronH_Nano_VL_V2",
"max_sequence_length": 131072,
"force_image_size": 512,
"downsample_ratio": 0.5,
"patch_size": 16,
"use_thumbnail": true,
"img_context_token_id": 18,
"img_context_token": "<image>",
"img_start_token": "<img>",
"img_end_token": "</img>",
"sound_context_token_id": 27,
"sound_context_token": "<so_embedding>",
"vit_hidden_size": 1280,
"projector_hidden_size": 20480,
"norm_mean": [0.48145466, 0.4578275, 0.40821073],
"norm_std": [0.26862954, 0.26130258, 0.27577711],
"vision_config": {
"version": "radio_v2.5-h",
"patch_size": 16,
"max_resolution": 2048,
"separate_video_embedder": true
},
"sound_config": {
"model_type": "parakeet",
"hidden_size": 1024,
"num_attention_heads": 8,
"num_hidden_layers": 24,
"intermediate_size": 4096,
"conv_kernel_size": 9,
"subsampling_conv_channels": 256,
"subsampling_conv_kernel_size": 3,
"subsampling_conv_stride": 2,
"subsampling_factor": 8,
"num_mel_bins": 128,
"projection_hidden_size": 4096,
"sampling_rate": 16000
},
"llm_config": {
"architectures": ["NemotronHForCausalLM"],
"model_type": "nemotron_h",
"num_hidden_layers": 4,
"hidden_size": 512,
"max_position_embeddings": 262144,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 64,
"layer_norm_epsilon": 1e-5,
"conv_kernel": 4,
"ssm_state_size": 128,
"mamba_num_heads": 16,
"mamba_head_dim": 32,
"n_groups": 8,
"hybrid_override_pattern": "ME*M",
"n_routed_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 256
}
}`
	const preprocessor = `{
"image_size": 512,
"patch_size": 16,
"downsample_ratio": 0.5,
"max_num_tiles": 12,
"use_thumbnail": true,
"norm_mean": [0.48145466, 0.4578275, 0.40821073],
"norm_std": [0.26862954, 0.26130258, 0.27577711]
}`

	dir := t.TempDir()
	files := []struct{ name, body string }{
		{"config.json", config},
		{"preprocessor_config.json", preprocessor},
		{"tokenizer.json", `{}`},
	}
	for _, f := range files {
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.body), 0o644); err != nil {
			t.Fatal(err)
		}
	}

	conv, tokenizer, err := LoadModelMetadata(os.DirFS(dir))
	if err != nil {
		t.Fatal(err)
	}
	if _, isOmni := conv.(*nemotronHNanoVLModel); !isOmni {
		t.Fatalf("unexpected converter type: %T", conv)
	}

	kv := conv.KV(tokenizer)
	checks := []struct {
		label string
		key   string
		want  any
	}{
		{"architecture", "general.architecture", "nemotron_h_omni"},
		// max_sequence_length (131072) is expected to win over the nested
		// max_position_embeddings (262144).
		{"context length", "context_length", uint32(131072)},
		{"vision block count", "vision.block_count", uint32(32)},
		{"vision image size", "vision.image_size", uint32(512)},
		{"projector scale factor", "vision.projector.scale_factor", uint32(2)},
		{"audio block count", "audio.block_count", uint32(24)},
		{"audio token id", "audio.sound_token_id", uint32(27)},
		{"audio subsampling factor", "audio.subsampling_factor", uint32(8)},
	}
	for _, c := range checks {
		if got := kv[c.key]; got != c.want {
			t.Fatalf("unexpected %s: got %v want %v", c.label, got, c.want)
		}
	}
}
// TestNemotronHNanoOmniReasoningV3LoadModelMetadata writes a
// NemotronH_Nano_Omni_Reasoning_V3 config and an empty tokenizer (no
// preprocessor config) to a temp directory, loads them through
// LoadModelMetadata, and checks the converter type, the vision patch limits
// and the audio metadata.
func TestNemotronHNanoOmniReasoningV3LoadModelMetadata(t *testing.T) {
	const config = `{
"architectures": ["NemotronH_Nano_Omni_Reasoning_V3"],
"model_type": "NemotronH_Nano_Omni_Reasoning_V3",
"max_sequence_length": 131072,
"force_image_size": 512,
"downsample_ratio": 0.5,
"patch_size": 16,
"img_context_token_id": 18,
"img_context_token": "<image>",
"img_start_token": "<img>",
"img_end_token": "</img>",
"sound_context_token_id": 27,
"sound_context_token": "<so_embedding>",
"vit_hidden_size": 1280,
"projector_hidden_size": 4096,
"vision_config": {
"version": "radio_v2.5-h",
"patch_size": 16,
"min_num_patches": 1024,
"max_num_patches": 13312,
"args": {
"min_num_patches": 1024,
"max_num_patches": 13312
}
},
"sound_config": {
"model_type": "parakeet",
"hidden_size": 1024,
"num_attention_heads": 8,
"num_hidden_layers": 24,
"intermediate_size": 4096,
"conv_kernel_size": 9,
"subsampling_conv_channels": 256,
"subsampling_conv_kernel_size": 3,
"subsampling_conv_stride": 2,
"subsampling_factor": 8,
"num_mel_bins": 128,
"projection_hidden_size": 4096,
"sampling_rate": 16000
},
"llm_config": {
"architectures": ["NemotronHForCausalLM"],
"model_type": "nemotron_h",
"num_hidden_layers": 4,
"hidden_size": 512,
"max_position_embeddings": 262144,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 64,
"layer_norm_epsilon": 1e-5,
"conv_kernel": 4,
"ssm_state_size": 128,
"mamba_num_heads": 16,
"mamba_head_dim": 32,
"n_groups": 8,
"hybrid_override_pattern": "ME*M",
"n_routed_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 256
}
}`

	dir := t.TempDir()
	files := []struct{ name, body string }{
		{"config.json", config},
		{"tokenizer.json", `{}`},
	}
	for _, f := range files {
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.body), 0o644); err != nil {
			t.Fatal(err)
		}
	}

	conv, tokenizer, err := LoadModelMetadata(os.DirFS(dir))
	if err != nil {
		t.Fatal(err)
	}
	if _, isOmni := conv.(*nemotronHNanoVLModel); !isOmni {
		t.Fatalf("unexpected converter type: %T", conv)
	}

	kv := conv.KV(tokenizer)
	checks := []struct {
		label string
		key   string
		want  any
	}{
		{"architecture", "general.architecture", "nemotron_h_omni"},
		{"vision block count", "vision.block_count", uint32(32)},
		{"vision min patches", "vision.min_num_patches", uint32(1024)},
		{"vision max patches", "vision.max_num_patches", uint32(13312)},
		{"audio block count", "audio.block_count", uint32(24)},
		{"audio token id", "audio.sound_token_id", uint32(27)},
	}
	for _, c := range checks {
		if got := kv[c.key]; got != c.want {
			t.Fatalf("unexpected %s: got %v want %v", c.label, got, c.want)
		}
	}
}
// TestNemotronHNanoVLTensorsRetainVisionAndAudio feeds a mix of language,
// vision, projector and audio tensors through Tensors and checks which names
// come out, the shapes of a few of them, and that the audio batchnorm
// num_batches_tracked tensor is dropped.
func TestNemotronHNanoVLTensorsRetainVisionAndAudio(t *testing.T) {
	model := &nemotronHNanoVLModel{
		LLMConfig: nemotronHModel{NGroups: 8},
	}

	input := []Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_a",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{name: "v.blk.0.attn_qkv.weight", shape: []uint64{3840, 1280}},
		&fakeTensor{name: "v.position_embd", shape: []uint64{1, 16384, 1280}},
		&fakeTensor{name: "v.cls_embd", shape: []uint64{10, 1280}},
		&fakeTensor{name: "mm.norm.weight", shape: []uint64{5120}},
		&fakeTensor{name: "a.feature_extractor.fb", shape: []uint64{1, 128, 257}},
		&fakeTensor{name: "a.subsampling.dw1.weight", shape: []uint64{256, 1, 3, 3}},
		&fakeTensor{name: "a.blk.0.conv_dw.weight", shape: []uint64{1024, 1, 9}},
		&fakeTensor{name: "a.blk.0.conv_pw1.weight", shape: []uint64{2048, 1024, 1}},
		&fakeTensor{name: "a.blk.0.conv_bn.num_batches_tracked", shape: []uint64{1}},
		&fakeTensor{name: "mm.a.1.weight", shape: []uint64{4096, 1024}},
	}

	// Index the converted tensors by name so presence and shape can be checked.
	shapes := make(map[string][]uint64)
	for _, converted := range model.Tensors(input) {
		shapes[converted.Name] = converted.Shape
	}

	expectedNames := []string{
		"blk.0.ssm_a",
		"v.blk.0.attn_q.weight",
		"v.blk.0.attn_k.weight",
		"v.blk.0.attn_v.weight",
		"v.position_embd",
		"v.cls_embd",
		"mm.norm.weight",
		"a.feature_extractor.fb",
		"a.subsampling.dw1.weight",
		"a.blk.0.conv_dw.weight",
		"a.blk.0.conv_pw1.weight",
		"mm.a.1.weight",
	}
	for _, name := range expectedNames {
		if _, present := shapes[name]; !present {
			t.Fatalf("expected tensor %q in output", name)
		}
	}

	wantSSM := []uint64{4, 1}
	if gotShape := shapes["blk.0.ssm_a"]; !slices.Equal(gotShape, wantSSM) {
		t.Fatalf("unexpected ssm_a shape: got %v want %v", gotShape, wantSSM)
	}

	wantPosition := []uint64{16384, 1280}
	if gotShape := shapes["v.position_embd"]; !slices.Equal(gotShape, wantPosition) {
		t.Fatalf("unexpected position embedding shape: got %v want %v", gotShape, wantPosition)
	}

	wantDepthwise := []uint64{1024, 9}
	if gotShape := shapes["a.blk.0.conv_dw.weight"], wantDepthwise; !slices.Equal(gotShape, wantDepthwise) {
		t.Fatalf("unexpected audio conv_dw shape: got %v want %v", gotShape, wantDepthwise)
	}

	wantPointwise := []uint64{2048, 1024}
	if gotShape := shapes["a.blk.0.conv_pw1.weight"]; !slices.Equal(gotShape, wantPointwise) {
		t.Fatalf("unexpected audio conv_pw1 shape: got %v want %v", gotShape, wantPointwise)
	}

	if _, present := shapes["a.blk.0.conv_bn.num_batches_tracked"]; present {
		t.Fatal("audio batchnorm num_batches_tracked should be omitted")
	}
}
// TestNemotronHNanoVLReplacements runs representative source tensor names from
// the language model, vision tower, projectors and sound encoder through the
// converter's replacement table and checks the resulting GGUF names.
func TestNemotronHNanoVLReplacements(t *testing.T) {
	replacer := strings.NewReplacer((&nemotronHNanoVLModel{}).Replacements()...)

	got := replacer.Replace("language_model.backbone.layers.1.mixer.fc1_latent_proj.weight")
	if want := "blk.1.ffn_latent_in.weight"; got != want {
		t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("language_model.lm_head.weight")
	if want := "output.weight"; got != want {
		t.Fatalf("unexpected lm_head replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("vision_model.radio_model.model.blocks.0.attn.qkv.weight")
	if want := "v.blk.0.attn_qkv.weight"; got != want {
		t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("mlp1.1.weight")
	if want := "mm.1.weight"; got != want {
		t.Fatalf("unexpected projector replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("sound_encoder.encoder.layers.0.self_attn.q_proj.weight")
	if want := "a.blk.0.attn_q.weight"; got != want {
		t.Fatalf("unexpected audio q_proj replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("sound_encoder.encoder.layers.0.conv.pointwise_conv1.weight")
	if want := "a.blk.0.conv_pw1.weight"; got != want {
		t.Fatalf("unexpected audio conv replacement: got %q want %q", got, want)
	}

	got = replacer.Replace("sound_projection.linear2.weight")
	if want := "mm.a.2.weight"; got != want {
		t.Fatalf("unexpected audio projector replacement: got %q want %q", got, want)
	}
}
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
m := &nemotronHModel{}
r := strings.NewReplacer(m.Replacements()...)

View File

@@ -42,8 +42,10 @@ func (t tensorBase) Kind() uint32 {
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
strings.HasPrefix(t.name, "a.feature_extractor.") || // audio feature-extractor constants are read with BackendGet and must be real F32 values
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights are kept F32 for im2col; this likely slows audio and should be revisited
strings.HasPrefix(t.name, "a.subsampling.") || // audio Parakeet subsampling weights are kept F32 for conv/linear stability; this likely slows audio and should be revisited
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights are kept F32; this likely slows audio and should be revisited
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.position_embd.weight" ||

View File

@@ -5,10 +5,12 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"maps"
"math"
"slices"
"strings"
@@ -23,6 +25,11 @@ type safetensorMetadata struct {
}
func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
fp8Block, err := safetensorsFP8BlockSize(fsys)
if err != nil {
return nil, err
}
var ts []Tensor
for _, p := range ps {
f, err := fsys.Open(p)
@@ -50,24 +57,47 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
names := make(map[string]struct{}, len(keys))
fp8Scales, err := collectSafetensorsFP8Scales(n, headers)
if err != nil {
return nil, err
}
for _, key := range keys {
if value := headers[key]; value.Type != "" {
if _, ok := fp8Scales.consumed[key]; ok {
continue
}
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
// Promote them to 1-dim so they can be stored in GGUF.
if len(value.Shape) == 0 {
value.Shape = []uint64{1}
}
var scale *safetensorScale
if value.Type == "F8_E4M3" {
if !fp8Block.ok {
return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", key)
}
scale = fp8Scales.byWeight[key]
if scale == nil {
return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", key)
}
}
ggufName := replacer.Replace(key)
if _, ok := names[ggufName]; ok {
return nil, fmt.Errorf("duplicate tensor name '%s' was found for this model", ggufName)
}
names[ggufName] = struct{}{}
ts = append(ts, safetensor{
fs: fsys,
path: p,
dtype: value.Type,
offset: safetensorsPad(n, value.Offsets[0]),
size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
fs: fsys,
path: p,
dtype: value.Type,
offset: safetensorsPad(n, value.Offsets[0]),
size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
scale: scale,
fp8Block: fp8Block,
tensorBase: &tensorBase{
name: ggufName,
shape: value.Shape,
@@ -85,12 +115,22 @@ func safetensorsPad(n, offset int64) int64 {
return 8 + n + offset
}
type safetensor struct {
fs fs.FS
path string
// safetensorScale locates the scale companion tensor of an FP8 weight inside
// the same safetensors file as the weight.
type safetensorScale struct {
// name is the scale tensor's key in the safetensors header.
name string
// dtype is the scale's safetensors dtype string; readScale handles "F32", "F16" and "BF16".
dtype string
// shape is the scale tensor's shape as recorded in the header.
shape []uint64
// offset and size give the byte range of the scale data within the file.
offset int64
size int64
}
// safetensor is a lazily-read tensor backed by a byte range of a safetensors
// file; the data is only read from fs/path when it is written out.
type safetensor struct {
// fs and path identify the safetensors file that holds the tensor data.
fs fs.FS
path string
// dtype is the source dtype string from the safetensors header (e.g. "BF16", "F8_E4M3").
dtype string
// offset and size give the byte range of the tensor data within the file.
offset int64
size int64
// scale is the FP8 scale companion; parseSafetensors only sets it for "F8_E4M3" tensors.
scale *safetensorScale
// fp8Block is the FP8 block size parsed from config.json; its ok flag is false when none was found.
fp8Block safetensorFP8BlockSize
*tensorBase
}
@@ -104,17 +144,26 @@ func (st safetensor) Kind() uint32 {
kind != tensorKindFP32 {
kind = tensorKindBF16
}
if st.dtype == "F8_E4M3" && kind != tensorKindFP32 {
kind = tensorKindBF16
}
return kind
}
// SourceDType returns the dtype string the tensor had in the source
// safetensors file, before any conversion performed when writing it out.
func (st safetensor) SourceDType() string {
return st.dtype
}
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
scale: st.scale.Clone(),
fp8Block: st.fp8Block,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
@@ -123,6 +172,19 @@ func (st safetensor) Clone() Tensor {
}
}
// Clone returns an independent copy of the scale descriptor, including its own
// copy of the shape slice. A nil receiver clones to nil.
func (ss *safetensorScale) Clone() *safetensorScale {
	if ss == nil {
		return nil
	}

	// Copy the scalar fields by value, then detach the one reference-typed field.
	dup := *ss
	dup.shape = slices.Clone(ss.shape)
	return &dup
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path)
if err != nil {
@@ -180,6 +242,16 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
}
f32s = bfloat16.DecodeFloat32(u8s)
case "F8_E4M3":
u8s := make([]uint8, st.size)
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
return 0, err
}
f32s, err = st.decodeFP8E4M3(u8s)
if err != nil {
return 0, err
}
default:
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
}
@@ -208,3 +280,334 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}
}
// safetensorsFP8Scales is the result of pairing FP8 weights with their scale
// companions for one safetensors file.
type safetensorsFP8Scales struct {
// byWeight maps an FP8 weight's header key to its scale descriptor.
byWeight map[string]*safetensorScale
// consumed is the set of header keys used as scales; parseSafetensors skips these as standalone tensors.
consumed map[string]struct{}
}
// collectSafetensorsFP8Scales pairs every "F8_E4M3" tensor in headers with its
// scale companion, when one exists.
//
// n is the header length passed through to safetensorsPad to turn header
// offsets into file offsets. FP8 tensors without a companion are left out of
// the result. It is an error for one scale tensor to be matched by more than
// one weight, or for a weight to have several candidate scales.
func collectSafetensorsFP8Scales(n int64, headers map[string]safetensorMetadata) (safetensorsFP8Scales, error) {
	byWeight := make(map[string]*safetensorScale)
	consumed := make(map[string]struct{})

	for weightName, meta := range headers {
		if meta.Type != "F8_E4M3" {
			continue
		}

		scaleName, scaleMeta, found, err := safetensorsFP8Scale(weightName, headers)
		switch {
		case err != nil:
			return safetensorsFP8Scales{}, err
		case !found:
			continue
		}

		if _, taken := consumed[scaleName]; taken {
			return safetensorsFP8Scales{}, fmt.Errorf("fp8 scale companion %q is used by multiple tensors", scaleName)
		}

		begin := safetensorsPad(n, scaleMeta.Offsets[0])
		end := safetensorsPad(n, scaleMeta.Offsets[1])
		byWeight[weightName] = &safetensorScale{
			name:   scaleName,
			dtype:  scaleMeta.Type,
			shape:  slices.Clone(scaleMeta.Shape),
			offset: begin,
			size:   end - begin,
		}
		consumed[scaleName] = struct{}{}
	}

	return safetensorsFP8Scales{byWeight: byWeight, consumed: consumed}, nil
}
// safetensorsFP8Scale looks up the scale companion of the FP8 tensor key.
//
// It returns the companion's header key and metadata with ok=true when exactly
// one candidate name is present in headers with a non-empty dtype, ok=false
// when none is, and an error when more than one candidate matches.
func safetensorsFP8Scale(key string, headers map[string]safetensorMetadata) (string, safetensorMetadata, bool, error) {
	candidates := safetensorsFP8ScaleCandidates(key)
	if base, isWeight := strings.CutSuffix(key, ".weight"); isWeight {
		// Keep support for compressed-tensors exports that place the scale name
		// between the module path and weight suffix.
		// NOTE(review): base+".weight_scale" equals key+"_scale" (and likewise
		// for "_scale_inv"), which safetensorsFP8ScaleCandidates already
		// returns, so appendUnique drops both - confirm the intended names.
		for _, suffix := range []string{".weight_scale", ".weight_scale_inv"} {
			candidates = appendUnique(candidates, base+suffix)
		}
	}

	var (
		matchKey   string
		matchValue safetensorMetadata
	)
	for _, candidate := range candidates {
		value, present := headers[candidate]
		if !present || value.Type == "" {
			continue
		}
		if matchKey != "" {
			return "", safetensorMetadata{}, false, fmt.Errorf("multiple fp8 scale companions for tensor %q: %q and %q", key, matchKey, candidate)
		}
		matchKey, matchValue = candidate, value
	}

	// Candidate names are never empty, so an empty matchKey means "not found"
	// and matchValue is still the zero value.
	return matchKey, matchValue, matchKey != "", nil
}
// safetensorsFP8ScaleCandidates lists, in lookup order, the header keys under
// which the scale companion of tensor key may be stored.
func safetensorsFP8ScaleCandidates(key string) []string {
	var candidates []string
	for _, suffix := range [...]string{"_scale", "_scale_inv", ".scale", ".scale_inv"} {
		candidates = appendUnique(candidates, key+suffix)
	}
	return candidates
}
// appendUnique appends value to values unless it is already present, and
// returns the resulting slice.
func appendUnique(values []string, value string) []string {
	if slices.Index(values, value) >= 0 {
		return values
	}
	return append(values, value)
}
// safetensorFP8BlockSize is the tile size over which one FP8 scale value
// applies, as read from config.json.
type safetensorFP8BlockSize struct {
// rows and cols are the tile dimensions; newSafetensorFP8BlockSize rejects non-positive values.
rows int
cols int
// ok reports whether a block size was found; the zero value means "no FP8 metadata".
ok bool
}
// safetensorsSourceQuantization is the subset of a config.json quantization
// section that safetensorsFP8BlockSize inspects to find the FP8 block size.
type safetensorsSourceQuantization struct {
// QuantMethod is compared case-insensitively against "fp8" and "compressed-tensors".
QuantMethod string `json:"quant_method"`
// Format is compared case-insensitively against "float-quantized".
Format string `json:"format"`
// WeightBlockSize is used as [rows, cols] when QuantMethod is "fp8".
WeightBlockSize []int `json:"weight_block_size"`
// ConfigGroups holds per-group settings; groups describing 8-bit float
// weights contribute their BlockStructure as [rows, cols].
ConfigGroups map[string]struct {
Format string `json:"format"`
Weights struct {
BlockStructure []int `json:"block_structure"`
NumBits int `json:"num_bits"`
Type string `json:"type"`
} `json:"weights"`
} `json:"config_groups"`
}
// safetensorsModelConfig captures the places in config.json where source
// quantization metadata is looked for: three alternative keys at the top level
// and the same three keys nested under "text_config".
type safetensorsModelConfig struct {
Quantization safetensorsSourceQuantization `json:"quantization"`
QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"`
CompressionConfig safetensorsSourceQuantization `json:"compression_config"`
TextConfig struct {
Quantization safetensorsSourceQuantization `json:"quantization"`
QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"`
CompressionConfig safetensorsSourceQuantization `json:"compression_config"`
} `json:"text_config"`
}
// safetensorsFP8BlockSize reads config.json from fsys and extracts the FP8
// weight block size.
//
// A missing config.json, or one without any FP8 block metadata, yields the
// zero value (ok=false) and no error. Every quantization section that declares
// a block size is collected; they must all agree, otherwise an error naming
// the two conflicting sizes is returned. Unparseable JSON and non-positive
// block sizes are errors as well.
func safetensorsFP8BlockSize(fsys fs.FS) (safetensorFP8BlockSize, error) {
	var none safetensorFP8BlockSize

	raw, err := fs.ReadFile(fsys, "config.json")
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return none, nil
	case err != nil:
		return none, err
	}

	var cfg safetensorsModelConfig
	if err := json.Unmarshal(sanitizeNonFiniteJSON(raw), &cfg); err != nil {
		return none, fmt.Errorf("parse config.json fp8 metadata: %w", err)
	}

	sections := []safetensorsSourceQuantization{
		cfg.Quantization,
		cfg.QuantizationConfig,
		cfg.CompressionConfig,
		cfg.TextConfig.Quantization,
		cfg.TextConfig.QuantizationConfig,
		cfg.TextConfig.CompressionConfig,
	}

	var found []safetensorFP8BlockSize
	for _, section := range sections {
		// Native fp8 checkpoints carry the block size directly on the section.
		if len(section.WeightBlockSize) == 2 && strings.EqualFold(section.QuantMethod, "fp8") {
			block, err := newSafetensorFP8BlockSize(section.WeightBlockSize[0], section.WeightBlockSize[1])
			if err != nil {
				return none, err
			}
			found = append(found, block)
		}

		// compressed-tensors style sections describe it per config group.
		hasGroups := strings.EqualFold(section.QuantMethod, "compressed-tensors") ||
			strings.EqualFold(section.Format, "float-quantized")
		if !hasGroups {
			continue
		}

		for _, group := range section.ConfigGroups {
			weights := group.Weights
			isFP8Block := strings.EqualFold(group.Format, "float-quantized") &&
				weights.NumBits == 8 &&
				strings.EqualFold(weights.Type, "float") &&
				len(weights.BlockStructure) == 2
			if !isFP8Block {
				continue
			}

			block, err := newSafetensorFP8BlockSize(weights.BlockStructure[0], weights.BlockStructure[1])
			if err != nil {
				return none, err
			}
			found = append(found, block)
		}
	}

	if len(found) == 0 {
		return none, nil
	}

	first := found[0]
	for _, other := range found[1:] {
		if other.rows == first.rows && other.cols == first.cols {
			continue
		}
		return none, fmt.Errorf("multiple fp8 block sizes in config.json: %dx%d and %dx%d", first.rows, first.cols, other.rows, other.cols)
	}
	return first, nil
}
// newSafetensorFP8BlockSize builds a valid (ok=true) block size from rows and
// cols, rejecting any dimension that is zero or negative.
func newSafetensorFP8BlockSize(rows, cols int) (safetensorFP8BlockSize, error) {
	switch {
	case rows <= 0, cols <= 0:
		return safetensorFP8BlockSize{}, fmt.Errorf("invalid fp8 block size %dx%d", rows, cols)
	}

	block := safetensorFP8BlockSize{ok: true}
	block.rows, block.cols = rows, cols
	return block, nil
}
// decodeFP8E4M3 dequantizes block-scaled FP8 (E4M3FN) weight bytes to float32.
//
// data holds one byte per element, row-major, for the 2D tensor described by
// st.shape. Each element is decoded with decodeFloat8E4M3FN and multiplied by
// the scale of the tile it falls in. Tiles are st.fp8Block.rows x
// st.fp8Block.cols, and the scale tensor must hold exactly one value per tile,
// counting partial tiles at the edges.
//
// An error is returned when the scale companion or block size is missing, when
// the tensor or its scale is not 2D, when data does not match the tensor
// shape, when the scale shape does not match the tile grid, or when the scale
// cannot be read.
func (st safetensor) decodeFP8E4M3(data []byte) ([]float32, error) {
	if st.scale == nil {
		return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", st.name)
	}
	if !st.fp8Block.ok {
		return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", st.name)
	}
	if len(st.shape) != 2 {
		return nil, fmt.Errorf("expected 2D fp8 tensor %q, got shape %v", st.name, st.shape)
	}

	rows, cols := int(st.shape[0]), int(st.shape[1])

	// Compare by division instead of rows*cols: the dimensions come from the
	// file header, and a product that wraps around could otherwise equal
	// len(data) and send the decode loop out of bounds.
	sizeOK := rows >= 0 && cols >= 0
	if sizeOK {
		if rows == 0 || cols == 0 {
			sizeOK = len(data) == 0
		} else {
			sizeOK = len(data)%cols == 0 && len(data)/cols == rows
		}
	}
	if !sizeOK {
		return nil, fmt.Errorf("fp8 tensor %q shape %v does not match %d bytes", st.name, st.shape, len(data))
	}

	// Validate the scale metadata before touching the file so malformed
	// headers are rejected without any I/O.
	if len(st.scale.shape) != 2 {
		return nil, fmt.Errorf("expected 2D fp8 scale tensor %q, got shape %v", st.scale.name, st.scale.shape)
	}

	blockRows, blockCols := st.fp8Block.rows, st.fp8Block.cols
	scaleRows, scaleCols := int(st.scale.shape[0]), int(st.scale.shape[1])
	expectedRows := (rows + blockRows - 1) / blockRows
	expectedCols := (cols + blockCols - 1) / blockCols
	if scaleRows != expectedRows || scaleCols != expectedCols {
		return nil, fmt.Errorf("unexpected fp8 scale shape %v for tensor %q shape %v; want [%d %d]", st.scale.shape, st.name, st.shape, expectedRows, expectedCols)
	}

	scale, err := st.readScale()
	if err != nil {
		return nil, err
	}
	if len(scale) != scaleRows*scaleCols {
		return nil, fmt.Errorf("fp8 scale tensor %q shape %v does not match decoded length %d", st.scale.name, st.scale.shape, len(scale))
	}

	f32s := make([]float32, len(data))
	for r := range rows {
		// All elements of a row share one row of scales; slice it once here
		// rather than recomputing the offset for every element.
		rowScales := scale[(r/blockRows)*scaleCols:][:scaleCols]
		rowIn := data[r*cols:][:cols]
		rowOut := f32s[r*cols:][:cols]
		for c, b := range rowIn {
			rowOut[c] = decodeFloat8E4M3FN(b) * rowScales[c/blockCols]
		}
	}
	return f32s, nil
}
// readScale loads the tensor's FP8 scale companion from the safetensors file
// and returns it as float32 values.
//
// The scale may be stored as "F32", "F16" or "BF16"; any other dtype is an
// error. The caller must ensure st.scale is non-nil.
func (st safetensor) readScale() ([]float32, error) {
	src, err := st.sectionReader(st.scale.offset, st.scale.size)
	if err != nil {
		return nil, fmt.Errorf("failed to read fp8 scale tensor %q: %w", st.scale.name, err)
	}
	if c, ok := src.(io.Closer); ok {
		defer c.Close()
	}

	nbytes := st.scale.size
	buffered := bufio.NewReaderSize(src, min(32<<10, int(nbytes)))
	read := func(dst any) error {
		return binary.Read(buffered, binary.LittleEndian, dst)
	}

	switch st.scale.dtype {
	case "BF16":
		raw := make([]uint8, nbytes)
		if err := read(raw); err != nil {
			return nil, err
		}
		return bfloat16.DecodeFloat32(raw), nil
	case "F16":
		halves := make([]uint16, nbytes/2)
		if err := read(halves); err != nil {
			return nil, err
		}
		values := make([]float32, 0, len(halves))
		for _, bits := range halves {
			values = append(values, float16.Frombits(bits).Float32())
		}
		return values, nil
	case "F32":
		values := make([]float32, nbytes/4)
		if err := read(values); err != nil {
			return nil, err
		}
		return values, nil
	default:
		return nil, fmt.Errorf("unsupported fp8 scale dtype %q for tensor %q", st.scale.dtype, st.scale.name)
	}
}
// sectionReader opens the tensor's file and returns a reader limited to the
// size bytes starting at offset.
//
// The returned reader also implements io.Closer, and closing it closes the
// underlying file. Positioning uses io.ReaderAt when the file supports it,
// then io.Seeker, and otherwise reads and discards offset bytes. If
// positioning fails the file is closed before the error is returned.
func (st safetensor) sectionReader(offset, size int64) (io.Reader, error) {
	file, err := st.fs.Open(st.path)
	if err != nil {
		return nil, err
	}

	var body io.Reader
	switch src := file.(type) {
	case io.ReaderAt:
		body = io.NewSectionReader(src, offset, size)
	case io.Seeker:
		if _, err := src.Seek(offset, io.SeekStart); err != nil {
			file.Close()
			return nil, err
		}
		body = io.LimitReader(file, size)
	default:
		if _, err := io.CopyN(io.Discard, file, offset); err != nil {
			file.Close()
			return nil, err
		}
		body = io.LimitReader(file, size)
	}

	return &readCloserReader{Reader: body, Closer: file}, nil
}
// readCloserReader pairs a Reader with an unrelated Closer, so a bounded view
// of a file can be read through Reader while Close still releases the file.
type readCloserReader struct {
io.Reader
io.Closer
}
// decodeFloat8E4M3FN converts one 8-bit float (1 sign bit, 4 exponent bits
// with bias 7, 3 mantissa bits) to float32.
//
// A zero exponent field encodes zero or a subnormal, an all-ones exponent
// with an all-ones mantissa encodes NaN, and every other pattern is a normal
// number; this format has no infinities.
func decodeFloat8E4M3FN(v byte) float32 {
	exponent := int(v>>3) & 0x0f
	mantissa := int(v & 0x07)

	var sign float32 = 1
	if v >= 0x80 {
		sign = -1
	}

	switch {
	case exponent == 0 && mantissa == 0:
		return 0 * sign
	case exponent == 0:
		// Subnormal: no implicit leading one, fixed exponent of -6.
		return sign * float32(math.Ldexp(float64(mantissa)/8, -6))
	case exponent == 0x0f && mantissa == 0x07:
		return float32(math.NaN())
	default:
		return sign * float32(math.Ldexp(1+float64(mantissa)/8, exponent-7))
	}
}

View File

@@ -3,8 +3,10 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/d4l3k/go-bfloat16"
@@ -231,6 +233,222 @@ func TestSafetensors(t *testing.T) {
}
}
// TestSafetensorWriteToFP8E4M3 writes four FP8 bytes followed by a single BF16
// scale to a file and checks that WriteTo emits the scaled values as BF16.
func TestSafetensorWriteToFP8E4M3(t *testing.T) {
	root, err := os.OpenRoot(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}
	defer root.Close()

	name := filepath.Base(t.Name())
	file, err := root.Create(name)
	if err != nil {
		t.Fatal(err)
	}

	// E4M3FN encodings for 1.0, 2.0, 0.5, and -1.0.
	weights := []byte{0x38, 0x40, 0x30, 0xb8}
	if _, err := file.Write(weights); err != nil {
		t.Fatal(err)
	}
	scale := bfloat16.EncodeFloat32([]float32{2})
	if _, err := file.Write(scale); err != nil {
		t.Fatal(err)
	}
	if err := file.Close(); err != nil {
		t.Fatal(err)
	}

	// The whole 2x2 tensor fits inside one 128x128 block, so one scale applies.
	tensor := safetensor{
		fs:       root.FS(),
		path:     name,
		dtype:    "F8_E4M3",
		offset:   0,
		size:     4,
		fp8Block: safetensorFP8BlockSize{rows: 128, cols: 128, ok: true},
		scale: &safetensorScale{
			name:   "linear.weight_scale",
			dtype:  "BF16",
			shape:  []uint64{1, 1},
			offset: 4,
			size:   2,
		},
		tensorBase: &tensorBase{
			name:  "linear.weight",
			shape: []uint64{2, 2},
		},
	}

	var out bytes.Buffer
	if _, err := tensor.WriteTo(&out); err != nil {
		t.Fatal(err)
	}

	expected := bfloat16.EncodeFloat32([]float32{2, 4, 1, -2})
	if diff := cmp.Diff(expected, out.Bytes()); diff != "" {
		t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
	}
}
// TestSafetensorWriteToFP8E4M3UsesConfiguredBlockSize decodes a 3x4 tensor of
// FP8 ones with a 2x3 block size and a 2x2 scale grid, and checks that each
// element is multiplied by the scale of the block it falls in.
func TestSafetensorWriteToFP8E4M3UsesConfiguredBlockSize(t *testing.T) {
	root, err := os.OpenRoot(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}
	defer root.Close()

	name := filepath.Base(t.Name())
	file, err := root.Create(name)
	if err != nil {
		t.Fatal(err)
	}

	// Twelve FP8 bytes of 0x38 (1.0), then four BF16 scales.
	weights := bytes.Repeat([]byte{0x38}, 12)
	if _, err := file.Write(weights); err != nil {
		t.Fatal(err)
	}
	scales := bfloat16.EncodeFloat32([]float32{1, 2, 3, 4})
	if _, err := file.Write(scales); err != nil {
		t.Fatal(err)
	}
	if err := file.Close(); err != nil {
		t.Fatal(err)
	}

	tensor := safetensor{
		fs:       root.FS(),
		path:     name,
		dtype:    "F8_E4M3",
		offset:   0,
		size:     12,
		fp8Block: safetensorFP8BlockSize{rows: 2, cols: 3, ok: true},
		scale: &safetensorScale{
			name:   "linear.weight_scale",
			dtype:  "BF16",
			shape:  []uint64{2, 2},
			offset: 12,
			size:   8,
		},
		tensorBase: &tensorBase{
			name:  "linear.weight",
			shape: []uint64{3, 4},
		},
	}

	var out bytes.Buffer
	if _, err := tensor.WriteTo(&out); err != nil {
		t.Fatal(err)
	}

	expected := bfloat16.EncodeFloat32([]float32{
		1, 1, 1, 2,
		1, 1, 1, 2,
		3, 3, 3, 4,
	})
	if diff := cmp.Diff(expected, out.Bytes()); diff != "" {
		t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
	}
}
func TestParseSafetensorsConsumesFP8ScaleCompanion(t *testing.T) {
tempDir := t.TempDir()
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
"linear.weight": {
Offsets: []int{0, 4},
Type: "F8_E4M3",
Shape: []int{2, 2},
},
"linear.weight_scale": {
Offsets: []int{4, 6},
Type: "BF16",
Shape: []int{1, 1},
},
})
writeFP8BlockConfig(t, tempDir, 128, 128)
tensors, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
if err != nil {
t.Fatal(err)
}
if len(tensors) != 1 {
t.Fatalf("expected one tensor, got %d", len(tensors))
}
if got := tensors[0].Name(); got != "linear.weight" {
t.Fatalf("unexpected tensor name %q", got)
}
if got := tensors[0].Kind(); got != tensorKindBF16 {
t.Fatalf("unexpected fp8 converted kind %d, want %d", got, tensorKindBF16)
}
}
func TestParseSafetensorsRejectsFP8WithoutBlockMetadata(t *testing.T) {
tempDir := t.TempDir()
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
"linear.weight": {
Offsets: []int{0, 4},
Type: "F8_E4M3",
Shape: []int{2, 2},
},
"linear.weight_scale": {
Offsets: []int{4, 6},
Type: "BF16",
Shape: []int{1, 1},
},
})
_, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
if err == nil || !strings.Contains(err.Error(), "missing fp8 block size metadata") {
t.Fatalf("expected missing fp8 block size metadata error, got %v", err)
}
}
func TestParseSafetensorsRejectsAmbiguousFP8ScaleCompanion(t *testing.T) {
tempDir := t.TempDir()
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
"linear.weight": {
Offsets: []int{0, 4},
Type: "F8_E4M3",
Shape: []int{2, 2},
},
"linear.weight_scale": {
Offsets: []int{4, 6},
Type: "BF16",
Shape: []int{1, 1},
},
"linear.weight.scale": {
Offsets: []int{6, 8},
Type: "BF16",
Shape: []int{1, 1},
},
})
writeFP8BlockConfig(t, tempDir, 128, 128)
_, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
if err == nil || !strings.Contains(err.Error(), "multiple fp8 scale companions") {
t.Fatalf("expected ambiguous fp8 scale companion error, got %v", err)
}
}
// writeFP8BlockConfig writes a config.json into dir whose compression_config
// declares a single float-quantized, 8-bit float weight group with block
// structure [rows, cols]. The test fails immediately if the file cannot be
// written.
func writeFP8BlockConfig(t *testing.T, dir string, rows, cols int) {
t.Helper()
config := fmt.Sprintf(`{
"architectures": ["GenericForCausalLM"],
"compression_config": {
"format": "float-quantized",
"config_groups": {
"group_0": {
"format": "float-quantized",
"weights": {
"type": "float",
"num_bits": 8,
"block_structure": [%d, %d]
}
}
}
}
}`, rows, cols)
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(config), 0o644); err != nil {
t.Fatal(err)
}
}
func TestSafetensorKind(t *testing.T) {
tests := []struct {
name string
@@ -259,6 +477,17 @@ func TestSafetensorKind(t *testing.T) {
},
expected: tensorKindFP16,
},
{
name: "BF16 audio feature extractor constants should return FP32",
st: safetensor{
tensorBase: &tensorBase{
name: "a.feature_extractor.fb",
shape: []uint64{1, 128, 257},
},
dtype: "BF16",
},
expected: tensorKindFP32,
},
{
name: "BF16 dtype with FP32 base kind should return FP32",
st: safetensor{

View File

@@ -5,6 +5,7 @@ import (
"errors"
"io"
"iter"
"maps"
"path"
"slices"
"strconv"
@@ -153,3 +154,54 @@ func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}
func sourceTensorKV(ts []*ggml.Tensor) KV {
sourceFP8 := make(map[string]struct{})
for _, t := range ts {
if writerSourceDType(t.WriterTo) == "F8_E4M3" {
sourceFP8[t.Name] = struct{}{}
}
}
if len(sourceFP8) == 0 {
return nil
}
return KV{
"source_quantization": "hf_fp8",
"source_fp8_tensors": slices.Sorted(maps.Keys(sourceFP8)),
}
}
type sourceDTypeTensor interface {
SourceDType() string
}
func writerSourceDType(w io.WriterTo) string {
switch w := w.(type) {
case sourceDTypeTensor:
return w.SourceDType()
case mergeGroup:
if len(w) == 0 {
return ""
}
dtype := sourceDType(w[0])
if dtype == "" {
return ""
}
for _, t := range w[1:] {
if sourceDType(t) != dtype {
return ""
}
}
return dtype
default:
return ""
}
}
func sourceDType(t Tensor) string {
if t, ok := t.(sourceDTypeTensor); ok {
return t.SourceDType()
}
return ""
}

View File

@@ -21,7 +21,8 @@ type fakeTensor struct {
shape []uint64
data []float32
repacker Repacker
sourceDType string
repacker Repacker
}
func (f fakeTensor) Name() string {
@@ -36,16 +37,21 @@ func (f fakeTensor) Kind() uint32 {
return 0
}
// SourceDType returns the source dtype configured on the fake tensor, so
// tests can exercise code that inspects source precision.
func (f fakeTensor) SourceDType() string {
return f.sourceDType
}
// SetRepacker stores fn as the fake tensor's repacker.
func (f *fakeTensor) SetRepacker(fn Repacker) {
f.repacker = fn
}
func (f fakeTensor) Clone() Tensor {
return &fakeTensor{
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
repacker: f.repacker,
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
sourceDType: f.sourceDType,
repacker: f.repacker,
}
}
@@ -995,3 +1001,43 @@ func TestMergeOrder(t *testing.T) {
})
}
}
func TestSourceTensorKVRecordsFP8OutputTensors(t *testing.T) {
fp8 := &fakeTensor{name: "linear.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
bf16 := &fakeTensor{name: "other.weight", shape: []uint64{2, 2}, sourceDType: "BF16"}
kv := sourceTensorKV([]*ggml.Tensor{
{Name: "blk.0.linear.weight", WriterTo: fp8},
{Name: "blk.0.other.weight", WriterTo: bf16},
})
if got := kv["source_quantization"]; got != "hf_fp8" {
t.Fatalf("source_quantization = %v, want hf_fp8", got)
}
got, ok := kv["source_fp8_tensors"].([]string)
if !ok {
t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"])
}
if diff := cmp.Diff([]string{"blk.0.linear.weight"}, got); diff != "" {
t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff)
}
}
// TestSourceTensorKVRecordsMergedFP8OutputTensors checks that a merged tensor
// is listed only when every member of its merge group came from FP8.
func TestSourceTensorKVRecordsMergedFP8OutputTensors(t *testing.T) {
	expert0 := &fakeTensor{name: "expert.0.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
	expert1 := &fakeTensor{name: "expert.1.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
	expert2 := &fakeTensor{name: "expert.2.weight", shape: []uint64{2, 2}, sourceDType: "BF16"}

	allFP8 := mergeGroup{expert0, expert1}
	mixed := mergeGroup{expert0, expert2}
	kv := sourceTensorKV([]*ggml.Tensor{
		{Name: "ffn_exps.weight", WriterTo: allFP8},
		{Name: "mixed_exps.weight", WriterTo: mixed},
	})

	names, isStrings := kv["source_fp8_tensors"].([]string)
	if !isStrings {
		t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"])
	}
	expected := []string{"ffn_exps.weight"}
	if diff := cmp.Diff(expected, names); diff != "" {
		t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff)
	}
}

View File

@@ -103,6 +103,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
t.Pre = "qwen2"
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
t.Pre = "qwen35"
case "b92c0824a58e1d8dc3221cf3e12c433c3a86f57e46d57229993489f0798e7702":
t.Pre = "laguna"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer
default:

View File

@@ -124,7 +124,8 @@
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose",
"/integrations/pi"
"/integrations/pi",
"/integrations/poolside"
]
},
{

View File

@@ -15,6 +15,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
- [Pi](/integrations/pi)
- [Poolside](/integrations/poolside)
## Assistants

View File

@@ -0,0 +1,54 @@
---
title: Poolside
---
Poolside is Poolside's software agent for the terminal, built for enterprise development workflows.
## Install
Install [Poolside](https://github.com/poolsideai/pool) by following the instructions in its repository.
## Usage with Ollama
### Quick setup
```shell
ollama launch pool
```
### Run directly with a model
```shell
ollama launch pool --model kimi-k2.6:cloud
```
### Pass arguments through to Poolside
Arguments after `--` are passed directly to Poolside:
```shell
ollama launch pool -- --help
```
## Manual setup
Poolside connects to Ollama using the OpenAI-compatible API via environment variables.
1. Set the environment variables:
```shell
export POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1
export POOLSIDE_API_KEY=ollama
```
2. Run Poolside with an Ollama model:
```shell
pool -m kimi-k2.6:cloud
```
Or run with environment variables inline:
```shell
POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1 POOLSIDE_API_KEY=ollama pool -m kimi-k2.6:cloud
```

View File

@@ -283,10 +283,11 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3n",
"gemma4",
"gptoss", "gpt-oss",
"laguna",
"llama4",
"mistral3",
"mllama",
"nemotron_h", "nemotron_h_moe",
"nemotron_h", "nemotron_h_moe", "nemotron_h_omni",
"nomic-bert",
"olmo3",
"qwen25vl",
@@ -897,7 +898,7 @@ func (f GGML) FlashAttention() bool {
"lfm2",
"lfm2moe",
"mistral3",
"nemotron_h", "nemotron_h_moe",
"nemotron_h", "nemotron_h_moe", "nemotron_h_omni",
"olmo3",
"qwen3", "qwen3moe",
"qwen35", "qwen35moe",

View File

@@ -0,0 +1,444 @@
package laguna
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Cache type identifiers, numbered from zero in declaration order.
// NOTE(review): presumably these index the sliding-window and causal KV
// caches used by the model - confirm against the cache setup, which is not
// visible here.
const (
cacheTypeSWA = iota
cacheTypeCausal
)
// Options holds the hyperparameters consulted by the attention, RoPE and
// mixture-of-experts code in this package.
type Options struct {
hiddenSize int
headDim int
// numHeads holds per-layer attention head counts; numHeadsForLayer falls
// back to the first entry, then to 1, when a layer has no positive value.
numHeads []int
numKVHeads int
eps float32
slidingWindow int
// slidingWindowPattern marks sliding layers; layers beyond its length are
// treated as non-sliding by layerIsSliding.
slidingWindowPattern []bool
// fullRope* parameters are used for RoPE on non-sliding layers.
fullRopeDim int
fullRopeBase, fullRopeScale float32
fullRopeOriginalContextLength int
fullRopeAttentionFactor float32
fullRopeBetaFast float32
fullRopeBetaSlow float32
// swaRope* parameters are used for RoPE on sliding layers.
swaRopeDim int
swaRopeBase, swaRopeScale float32
numExperts, numExpertsUsed int
normTopKProb bool
routedScalingFactor float32
// decoderSparseStep selects MoE layers as those where (layer+1) is a
// multiple of the step; values <= 0 are treated as 1 by layerUsesMoE.
decoderSparseStep int
// denseLayers lists layers that never use MoE.
denseLayers map[int]bool
}
// numHeadsForLayer returns the attention head count for layer.
//
// It prefers the layer's own positive entry in numHeads, then the first
// entry if that is positive, and finally 1.
func (o *Options) numHeadsForLayer(layer int) int {
	if layer < len(o.numHeads) {
		if n := o.numHeads[layer]; n > 0 {
			return n
		}
	}
	if len(o.numHeads) == 0 || o.numHeads[0] <= 0 {
		return 1
	}
	return o.numHeads[0]
}
// layerIsSliding reports whether layer is flagged in slidingWindowPattern;
// layers past the end of the pattern are not sliding.
func (o *Options) layerIsSliding(layer int) bool {
	if layer >= len(o.slidingWindowPattern) {
		return false
	}
	return o.slidingWindowPattern[layer]
}
// layerUsesMoE reports whether layer uses the mixture-of-experts MLP.
//
// It is false when the model has no experts or the layer is listed in
// denseLayers. Otherwise a layer is sparse when (layer+1) is a multiple of
// decoderSparseStep, with non-positive steps treated as 1.
func (o *Options) layerUsesMoE(layer int) bool {
	switch {
	case o.numExperts == 0:
		return false
	case o.denseLayers[layer]:
		return false
	}

	step := 1
	if o.decoderSparseStep > 0 {
		step = o.decoderSparseStep
	}
	return (layer+1)%step == 0
}
// applyRotaryPositionEmbeddings applies NeoX-style RoPE to states at the given
// positions, using the parameter set that matches the layer type.
//
// Sliding layers use the swaRope* dimension, base and scale with no further
// options. All other layers use the fullRope* values plus the original context
// length, an extrapolation factor of 1, the attention factor, and the
// beta-fast/beta-slow settings.
func (o *Options) applyRotaryPositionEmbeddings(ctx ml.Context, layer int, states, positions ml.Tensor) ml.Tensor {
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}

	dim, base, scale := o.fullRopeDim, o.fullRopeBase, o.fullRopeScale
	if o.layerIsSliding(layer) {
		dim, base, scale = o.swaRopeDim, o.swaRopeBase, o.swaRopeScale
	} else {
		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.fullRopeOriginalContextLength),
			rope.WithExtrapolationFactor(1),
			rope.WithAttentionFactor(o.fullRopeAttentionFactor),
			rope.WithBetaFast(o.fullRopeBetaFast),
			rope.WithBetaSlow(o.fullRopeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, dim, base, 1./scale, ropeOpts...)
}
// Attention holds the weights used by Attention.Forward: query/key/value
// projections, RMS norms for query and key, a gate projection whose output
// scales the attention result, and the output projection.
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Gate *nn.Linear `gguf:"attn_g"`
Output *nn.Linear `gguf:"attn_output"`
}
// Forward computes gated self-attention for one layer and returns the output
// projection of the result.
//
// The head count is looked up per layer, query and key are RMS-normalized and
// rotated with the layer's RoPE settings, and the attention output is
// multiplied by softplus(gate) before the output projection.
func (sa *Attention) Forward(ctx ml.Context, layer int, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
// Dim(1) of the input is used as the last dimension of every reshape below.
batchSize := hiddenStates.Dim(1)
numHeads := opts.numHeadsForLayer(layer)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
gate := sa.Gate.Forward(ctx, hiddenStates)
// Split the projections into heads: query uses the per-layer head count,
// key and value use the shared KV head count.
query = query.Reshape(ctx, opts.headDim, numHeads, batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
// Normalize query and key before applying rotary embeddings.
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, layer, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, layer, key, positions)
// Attention is scaled by 1/sqrt(headDim).
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), cache)
// Laguna applies the per-head gate softplus in float32, then casts back.
gate = gate.Cast(ctx, ml.DTypeF32).Softplus(ctx).Cast(ctx, attention.DType())
// The gate is reshaped to one value per head so it multiplies across headDim.
attention = attention.Mul(ctx, gate.Reshape(ctx, 1, numHeads, batchSize))
// Merge the heads back into a single feature dimension.
attention = attention.Reshape(ctx, opts.headDim*numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}
// MLP is the feed-forward block of a layer. New assigns either a *dense or a
// *sparse value to each layer's MLP field.
type MLP interface {
// Forward maps the given hidden states to the block's output.
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
// dense is a gated feed-forward block with gate, up and down projections.
type dense struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

// Forward computes Down(SILU(Gate(x), Up(x))). The options argument is unused.
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
	gated := mlp.Gate.Forward(ctx, hiddenStates)
	up := mlp.Up.Forward(ctx, hiddenStates)
	return mlp.Down.Forward(ctx, gated.SILU(ctx, up))
}
// sparse is a mixture-of-experts feed-forward block: a router, batched expert
// gate/up/down projections, a dense shared expert and an optional bias that is
// added to the router scores for expert selection only.
type sparse struct {
	Router       *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate         *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up           *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down         *nn.LinearBatch `gguf:"ffn_down_exps"`
	SharedExpert *dense          `gguf:",suf:_shexp"`
	ExpProbsBias ml.Tensor       `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}

// topKIndices returns the indices of the numExpertsUsed highest scores, after
// adding ExpProbsBias when it is present.
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
	biased := scores
	if moe.ExpProbsBias != nil {
		biased = scores.Add(ctx, moe.ExpProbsBias)
	}
	return biased.TopK(ctx, opts.numExpertsUsed)
}

// Forward routes each token to its selected experts and adds the shared
// expert output.
//
// Router logits are cast to float32 and passed through a sigmoid. Experts are
// selected from the (optionally biased) scores, but the routing weights are
// gathered from the unbiased scores. When normTopKProb is set the weights are
// divided by their row sum; they are always scaled by routedScalingFactor.
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	tokens := hiddenStates.Dim(1)
	sharedInput := hiddenStates

	scores := moe.Router.Forward(ctx, hiddenStates).Cast(ctx, ml.DTypeF32).Sigmoid(ctx)
	selected := moe.topKIndices(ctx, scores, opts)

	weights := scores.Reshape(ctx, 1, opts.numExperts, tokens).Rows(ctx, selected)
	if opts.normTopKProb {
		weights = weights.Reshape(ctx, opts.numExpertsUsed, tokens)
		weights = weights.Div(ctx, weights.SumRows(ctx))
		weights = weights.Reshape(ctx, 1, opts.numExpertsUsed, tokens)
	}
	weights = weights.Scale(ctx, float64(opts.routedScalingFactor))

	expertInput := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, tokens)
	up := moe.Up.Forward(ctx, expertInput, selected)
	gated := moe.Gate.Forward(ctx, expertInput, selected).SILU(ctx, up)
	experts := moe.Down.Forward(ctx, gated, selected)
	experts = experts.Mul(ctx, weights)

	// Sum the per-expert outputs by adding one strided view per expert slot.
	combined := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for slot := 1; slot < opts.numExpertsUsed; slot++ {
		view := experts.View(ctx, slot*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))
		combined = combined.Add(ctx, view)
	}

	return combined.Add(ctx, moe.SharedExpert.Forward(ctx, sharedInput, opts))
}
// Layer is one decoder block: a normalised attention sub-block followed by a
// normalised feed-forward sub-block, each with a residual connection.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	*Attention
	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP
}

// Forward runs the block for the given layer index.
//
// When outputs is non-nil, both the attention result and its residual are
// reduced to the rows selected by outputs before the feed-forward sub-block
// runs.
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	skip := hiddenStates

	normed := l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	attn := l.Attention.Forward(ctx, layer, normed, positions, cache, opts)
	if outputs != nil {
		attn = attn.Rows(ctx, outputs)
		skip = skip.Rows(ctx, outputs)
	}
	afterAttention := attn.Add(ctx, skip)

	ffnInput := l.MLPNorm.Forward(ctx, afterAttention, opts.eps)
	ffnOutput := l.MLP.Forward(ctx, ffnInput, opts)
	return ffnOutput.Add(ctx, afterAttention)
}
// Model is the Laguna text model: token embedding, decoder layers, a final
// RMS norm and the output projection, together with its tokenizer and
// options.
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
// NOTE(review): the alt tag presumably lets the output projection reuse
// the token embedding weights when no separate tensor exists; confirm
// against the gguf tag loader.
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
// New builds a Laguna model from the given configuration.
//
// It rejects configurations this implementation does not handle: attention
// sinks enabled, an attention gating type other than 1, Q/K norm disabled, an
// expert gating function other than 2, or a tokenizer pre-tokenizer other
// than "laguna" (model.ErrUnsupportedTokenizer). Each layer receives a sparse
// or dense MLP according to Options.layerUsesMoE, and the cache wraps a
// sliding-window cache together with a causal cache.
func New(c fs.Config) (model.Model, error) {
	switch {
	case c.Bool("attention.sink_enabled"):
		return nil, fmt.Errorf("laguna: SWA attention sinks are not supported")
	case c.Uint("attention.gating_type") != 1:
		return nil, fmt.Errorf("laguna: unsupported attention gating type %d", c.Uint("attention.gating_type"))
	case !c.Bool("attention.qk_norm"):
		return nil, fmt.Errorf("laguna: Q/K RMSNorm is required")
	}
	if gating := c.Uint("expert_gating_func"); gating != 2 {
		return nil, fmt.Errorf("laguna: unsupported expert gating function %d", gating)
	}

	blockCount := int(c.Uint("block_count"))
	opts := newOptions(c, blockCount)

	layers := make([]Layer, blockCount)
	for i := range layers {
		if !opts.layerUsesMoE(i) {
			layers[i].MLP = &dense{}
			continue
		}
		layers[i].MLP = &sparse{}
	}

	if c.String("tokenizer.ggml.pre") != "laguna" {
		return nil, model.ErrUnsupportedTokenizer
	}
	pretokenizers := []string{
		`(?:\r?\n)+(?!\r?\n)`,
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
	}

	vocabulary := &tokenizer.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Types:  c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
		EOS: append(
			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
			c.Ints("tokenizer.ggml.eos_token_ids")...,
		),
	}

	m := &Model{
		Tokenizer: tokenizer.NewBytePairEncoding(vocabulary, pretokenizers...),
		Layers:    layers,
		Options:   opts,
	}
	m.Cache = kvcache.NewWrapperCache(
		kvcache.NewSWACache(int32(opts.slidingWindow), m.Shift),
		kvcache.NewCausalCache(m.Shift),
	)
	return m, nil
}
// newOptions reads the Laguna hyperparameters from c for a model with
// numLayers layers.
//
// Dense layers are the union of the "dense_layers" list and the first
// "leading_dense_block_count" layers. A RoPE scaling factor of 0 is replaced
// by 1 for both the full-attention and the sliding-window settings. The
// full-attention attention factor comes from lagunaAttentionFactor.
func newOptions(c fs.Config, numLayers int) *Options {
	dense := make(map[int]bool)
	for _, idx := range configUints(c, "dense_layers") {
		dense[int(idx)] = true
	}
	leading := c.Uint("leading_dense_block_count")
	for idx := uint32(0); idx < leading; idx++ {
		dense[int(idx)] = true
	}

	// oneIfZero guards the reciprocal taken when the scale is used.
	oneIfZero := func(scale float32) float32 {
		if scale == 0 {
			return 1
		}
		return scale
	}
	fullScale := oneIfZero(c.Float("rope.scaling.factor", 1))
	swaScale := oneIfZero(c.Float("rope.swa.scaling.factor", 1))

	scalingType := c.String("rope.scaling.type")
	attentionFactor := lagunaAttentionFactor(scalingType, fullScale, c.Float("rope.scaling.attn_factor"))

	opts := &Options{}
	opts.hiddenSize = int(c.Uint("embedding_length"))
	opts.headDim = int(c.Uint("attention.key_length"))
	opts.numHeads = expandIntArray(configUints(c, "attention.head_count"), numLayers, c.Uint("attention.head_count", 1))
	opts.numKVHeads = int(c.Uint("attention.head_count_kv"))
	opts.eps = c.Float("attention.layer_norm_rms_epsilon", 1e-6)
	opts.slidingWindow = int(c.Uint("attention.sliding_window", 512))
	opts.slidingWindowPattern = slidingWindowPattern(c, numLayers)
	opts.fullRopeDim = int(c.Uint("rope.dimension_count", c.Uint("attention.key_length")))
	opts.fullRopeBase = c.Float("rope.freq_base", 500000)
	opts.fullRopeScale = fullScale
	opts.fullRopeOriginalContextLength = int(c.Uint("rope.scaling.original_context_length", 4096))
	opts.fullRopeAttentionFactor = attentionFactor
	opts.fullRopeBetaFast = c.Float("rope.scaling.beta_fast", 64)
	opts.fullRopeBetaSlow = c.Float("rope.scaling.beta_slow", 1)
	opts.swaRopeDim = int(c.Uint("rope.swa.dimension_count", c.Uint("attention.key_length")))
	opts.swaRopeBase = c.Float("rope.swa.freq_base", 10000)
	opts.swaRopeScale = swaScale
	opts.numExperts = int(c.Uint("expert_count"))
	opts.numExpertsUsed = int(c.Uint("expert_used_count"))
	opts.normTopKProb = c.Bool("expert_weights_norm", true)
	opts.routedScalingFactor = c.Float("expert_weights_scale", 1)
	opts.decoderSparseStep = int(c.Uint("decoder_sparse_step", 1))
	opts.denseLayers = dense
	return opts
}
func lagunaAttentionFactor(ropeType string, scaleFactor, attentionFactor float32) float32 {
if attentionFactor != 0 {
return attentionFactor
}
if ropeType == "yarn" && scaleFactor > 1 {
return float32(0.1*math.Log(float64(scaleFactor)) + 1)
}
return 1
}
// slidingWindowPattern returns one flag per layer that is true for
// sliding-window layers.
//
// The "attention.sliding_window_pattern" list is used as is when it has
// exactly numLayers entries. Otherwise, when "attention.layer_types" has
// exactly numLayers entries, a layer is sliding when its type equals 1. If
// neither matches, every layer is reported as non-sliding.
func slidingWindowPattern(c fs.Config, numLayers int) []bool {
	if explicit := c.Bools("attention.sliding_window_pattern"); len(explicit) == numLayers {
		return explicit
	}

	layerTypes := configUints(c, "attention.layer_types")
	derived := make([]bool, numLayers)
	if len(layerTypes) != numLayers {
		return derived
	}
	for i := range derived {
		derived[i] = layerTypes[i] == 1
	}
	return derived
}
// configUints reads key as a list of uint32 values, accepting several
// encodings of the same setting.
//
// It tries, in order: an optional Uints method on c, the Ints accessor
// (converted element-wise), and finally the scalar Uint accessor wrapped in a
// one-element slice. A list consisting of a single 0 is only accepted when
// the key is actually present (looked up both with and without the
// architecture prefix); a scalar 0 is never accepted. It returns nil when no
// encoding yields a value.
func configUints(c fs.Config, key string) []uint32 {
	present := c.Value(c.Architecture()+"."+key) != nil || c.Value(key) != nil

	type uintsReader interface {
		Uints(string, ...[]uint32) []uint32
	}
	if reader, ok := c.(uintsReader); ok {
		values := reader.Uints(key)
		if len(values) > 0 && (present || len(values) != 1 || values[0] != 0) {
			return values
		}
	}

	if ints := c.Ints(key); len(ints) > 0 && (present || len(ints) != 1 || ints[0] != 0) {
		converted := make([]uint32, 0, len(ints))
		for _, v := range ints {
			converted = append(converted, uint32(v))
		}
		return converted
	}

	scalar := c.Uint(key)
	if scalar == 0 {
		return nil
	}
	return []uint32{scalar}
}
// expandIntArray returns a slice of exactly n ints built from values.
//
// The first len(values) entries are copied (converted to int); any remaining
// entries are filled with the first element of values. When values is empty,
// every entry is fallback. If values is longer than n, the extra entries are
// ignored.
func expandIntArray(values []uint32, n int, fallback uint32) []int {
	if len(values) == 0 {
		values = []uint32{fallback}
	}

	// Entries beyond the supplied values repeat the first one.
	pad := int(values[0])

	out := make([]int, n)
	for i := range out {
		if i < len(values) {
			out[i] = int(values[i])
			continue
		}
		out[i] = pad
	}
	return out
}
// Shift applies the layer's rotary position embedding to key, using shift as
// the positions. New passes it to both KV caches. The returned error is
// always nil.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionEmbeddings(ctx, layer, key, shift), nil
}
// Forward runs the decoder over batch and returns the output projection of
// the final normalised hidden states.
//
// Before each layer the cache is pointed at that layer; when the cache is a
// *kvcache.WrapperCache the sliding-window or causal sub-cache is selected to
// match the layer type. Only the last layer receives batch.Outputs, so earlier
// layers process every position. The returned error is always nil.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	last := len(m.Layers) - 1
	for i := range m.Layers {
		if m.Cache != nil {
			m.Cache.SetLayer(i)
			if wrapper, ok := m.Cache.(*kvcache.WrapperCache); ok {
				if m.Options.layerIsSliding(i) {
					wrapper.SetLayerType(cacheTypeSWA)
				} else {
					wrapper.SetLayerType(cacheTypeCausal)
				}
			}
		}

		var outputs ml.Tensor
		if i == last {
			outputs = batch.Outputs
		}

		block := m.Layers[i]
		hiddenStates = block.Forward(ctx, i, hiddenStates, positions, outputs, m.Cache, m.Options)
	}

	normed := m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, normed), nil
}
// init registers New with the model package under the name "laguna".
func init() {
model.Register("laguna", New)
}
// Compile-time check that *Model implements model.Model.
var _ model.Model = (*Model)(nil)

View File

@@ -0,0 +1,237 @@
package laguna
import (
"iter"
"math"
"testing"
)
type testConfig map[string]any
func (c testConfig) Architecture() string { return "laguna" }
func (c testConfig) key(key string) string {
switch {
case len(key) >= len("tokenizer.") && key[:len("tokenizer.")] == "tokenizer.":
return key
case len(key) >= len("general.") && key[:len("general.")] == "general.":
return key
default:
return "laguna." + key
}
}
func (c testConfig) String(key string, defaultValue ...string) string {
if v, ok := c[c.key(key)].(string); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return ""
}
func (c testConfig) Uint(key string, defaultValue ...uint32) uint32 {
switch v := c[c.key(key)].(type) {
case uint32:
return v
case int:
return uint32(v)
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return 0
}
func (c testConfig) Float(key string, defaultValue ...float32) float32 {
if v, ok := c[c.key(key)].(float32); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return 0
}
func (c testConfig) Bool(key string, defaultValue ...bool) bool {
if v, ok := c[c.key(key)].(bool); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return false
}
func (c testConfig) Strings(key string, defaultValue ...[]string) []string {
if v, ok := c[c.key(key)].([]string); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return nil
}
func (c testConfig) Ints(key string, defaultValue ...[]int32) []int32 {
if v, ok := c[c.key(key)].([]int32); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return nil
}
func (c testConfig) Uints(key string, defaultValue ...[]uint32) []uint32 {
if v, ok := c[c.key(key)].([]uint32); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return nil
}
func (c testConfig) Floats(key string, defaultValue ...[]float32) []float32 {
if v, ok := c[c.key(key)].([]float32); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return nil
}
func (c testConfig) Bools(key string, defaultValue ...[]bool) []bool {
if v, ok := c[c.key(key)].([]bool); ok {
return v
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return nil
}
func (c testConfig) Len() int { return len(c) }
func (c testConfig) Keys() iter.Seq[string] {
return func(yield func(string) bool) {
for key := range c {
if !yield(key) {
return
}
}
}
}
func (c testConfig) Value(key string) any { return c[key] }
// TestNewOptionsLayerConfig checks that newOptions derives the per-layer head
// counts, the sliding-window pattern, the dense/MoE split and the RoPE
// dimensions from a four-layer configuration.
func TestNewOptionsLayerConfig(t *testing.T) {
	config := testConfig{
		"laguna.block_count":                          uint32(4),
		"laguna.embedding_length":                     uint32(128),
		"laguna.attention.key_length":                 uint32(16),
		"laguna.attention.head_count":                 []uint32{8, 16, 16, 16},
		"laguna.attention.head_count_kv":              uint32(4),
		"laguna.attention.layer_norm_rms_epsilon":     float32(1e-6),
		"laguna.attention.sliding_window":             uint32(512),
		"laguna.attention.sliding_window_pattern":     []bool{false, true, true, true},
		"laguna.rope.dimension_count":                 uint32(8),
		"laguna.rope.freq_base":                       float32(500000),
		"laguna.rope.scaling.factor":                  float32(32),
		"laguna.rope.scaling.original_context_length": uint32(4096),
		"laguna.rope.swa.dimension_count":             uint32(16),
		"laguna.rope.swa.freq_base":                   float32(10000),
		"laguna.expert_count":                         uint32(32),
		"laguna.expert_used_count":                    uint32(4),
		"laguna.expert_weights_norm":                  true,
		"laguna.expert_weights_scale":                 float32(2.5),
		"laguna.decoder_sparse_step":                  uint32(1),
		"laguna.dense_layers":                         []uint32{0},
	}
	options := newOptions(config, 4)

	if got := options.numHeadsForLayer(0); got != 8 {
		t.Fatalf("layer 0 heads = %d, want 8", got)
	}
	if got := options.numHeadsForLayer(1); got != 16 {
		t.Fatalf("layer 1 heads = %d, want 16", got)
	}

	// Each entry fails the test with msg when failed is true.
	layerChecks := []struct {
		failed bool
		msg    string
	}{
		{options.layerIsSliding(0), "layer 0 should be full attention"},
		{!options.layerIsSliding(1), "layer 1 should be sliding attention"},
		{options.layerUsesMoE(0), "layer 0 should be dense"},
		{!options.layerUsesMoE(1), "layer 1 should use MoE"},
	}
	for _, check := range layerChecks {
		if check.failed {
			t.Fatal(check.msg)
		}
	}

	if options.fullRopeDim != 8 || options.swaRopeDim != 16 {
		t.Fatalf("rope dims = full %d swa %d, want 8/16", options.fullRopeDim, options.swaRopeDim)
	}
}
// TestNewOptionsYarnAttentionFactorFallback checks that, with YaRN scaling
// and no explicit attention factor, newOptions derives the factor as
// 0.1*ln(scale)+1.
func TestNewOptionsYarnAttentionFactorFallback(t *testing.T) {
	config := testConfig{
		"laguna.block_count":             uint32(1),
		"laguna.embedding_length":        uint32(128),
		"laguna.attention.key_length":    uint32(16),
		"laguna.attention.head_count":    uint32(8),
		"laguna.attention.head_count_kv": uint32(4),
		"laguna.rope.scaling.type":       "yarn",
		"laguna.rope.scaling.factor":     float32(32),
	}

	got := newOptions(config, 1).fullRopeAttentionFactor
	want := float32(0.1*math.Log(32) + 1)
	if delta := math.Abs(float64(got - want)); delta > 1e-6 {
		t.Fatalf("fullRopeAttentionFactor = %v, want %v", got, want)
	}
}
// TestNewRejectsUnsupportedLagunaVariants checks that New returns an error
// for each configuration variant it does not support.
func TestNewRejectsUnsupportedLagunaVariants(t *testing.T) {
	type variant struct {
		name string
		cfg  testConfig
	}
	variants := []variant{
		{
			name: "attention sinks",
			cfg:  testConfig{"laguna.attention.sink_enabled": true},
		},
		{
			name: "non per-head gate",
			cfg:  testConfig{"laguna.attention.gating_type": uint32(0)},
		},
		{
			name: "missing qk norm",
			cfg:  testConfig{"laguna.attention.gating_type": uint32(1)},
		},
		{
			name: "non sigmoid experts",
			cfg: testConfig{
				"laguna.attention.gating_type": uint32(1),
				"laguna.attention.qk_norm":     true,
				"laguna.expert_gating_func":    uint32(1),
			},
		},
	}

	for _, v := range variants {
		t.Run(v.name, func(t *testing.T) {
			_, err := New(v.cfg)
			if err != nil {
				return
			}
			t.Fatal("expected unsupported variant error")
		})
	}
}

View File

@@ -11,6 +11,7 @@ import (
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/laguna"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"

View File

@@ -0,0 +1,355 @@
package nemotronh
import (
"errors"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor converts images into normalised planar float tiles. It is
// populated from the "vision.*" configuration keys by newImageProcessor.
type ImageProcessor struct {
// imageSize is the edge length in pixels of each square tile in tiled mode.
imageSize int
// patchSize is the edge length in pixels of one patch in dynamic mode.
patchSize int
numChannels int
// maxTiles bounds the number of tiles (excluding the thumbnail) in tiled mode.
maxTiles int
// minNumPatches and maxNumPatches enable dynamic-resolution mode when
// either is positive.
minNumPatches int
maxNumPatches int
// useThumbnail appends a whole-image tile when more than one tile is produced.
useThumbnail bool
// projectorScale is the divisor the dynamic patch grid is rounded to.
projectorScale int
// imageMean and imageStd are the per-channel normalisation constants.
imageMean [3]float32
imageStd [3]float32
}
// processedVisionTile is one normalised image tile produced by ImageProcessor.
type processedVisionTile struct {
// data holds the pixel values in planar (channel-major) order.
data []float32
// size is the tile's width (X) and height (Y) in pixels.
size image.Point
}
// newImageProcessor builds an ImageProcessor from the "vision.*" keys of c.
//
// Size-like settings that are missing or not strictly positive fall back to
// their defaults (image size 512, patch size 16, 3 channels, 12 tiles,
// projector scale 2). Mean and std default to the CLIP constants and are
// overridden only when the configured list has at least three entries.
func newImageProcessor(c fs.Config) ImageProcessor {
	mean := c.Floats("vision.image_mean")
	std := c.Floats("vision.image_std")

	// positiveOr replaces non-positive values with fallback.
	positiveOr := func(v, fallback int) int {
		if v <= 0 {
			return fallback
		}
		return v
	}

	p := ImageProcessor{
		imageSize:      positiveOr(int(c.Uint("vision.image_size", 512)), 512),
		patchSize:      positiveOr(int(c.Uint("vision.patch_size", 16)), 16),
		numChannels:    positiveOr(int(c.Uint("vision.num_channels", 3)), 3),
		maxTiles:       positiveOr(int(c.Uint("vision.max_tiles", 12)), 12),
		minNumPatches:  int(c.Uint("vision.min_num_patches")),
		maxNumPatches:  int(c.Uint("vision.max_num_patches")),
		useThumbnail:   c.Bool("vision.use_thumbnail", true),
		projectorScale: positiveOr(int(c.Uint("vision.projector.scale_factor", 2)), 2),
		imageMean:      imageproc.ClipDefaultMean,
		imageStd:       imageproc.ClipDefaultSTD,
	}

	if len(mean) >= 3 {
		copy(p.imageMean[:], mean[:3])
	}
	if len(std) >= 3 {
		copy(p.imageStd[:], std[:3])
	}
	return p
}
// ProcessImage composites img and converts it to normalised tiles, using the
// dynamic-resolution path when it is enabled and the fixed-size tiling path
// otherwise. Only the dynamic path can return an error.
func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionTile, error) {
	composited := imageproc.Composite(img)
	if !p.useDynamicResolution() {
		return p.processTiledImage(composited), nil
	}
	return p.processDynamicImage(composited)
}
// useDynamicResolution reports whether a minimum or maximum patch count is
// configured, which selects the dynamic-resolution path.
func (p ImageProcessor) useDynamicResolution() bool {
	neitherSet := p.minNumPatches <= 0 && p.maxNumPatches <= 0
	return !neitherSet
}
// processTiledImage resizes img to the tile grid whose aspect ratio is closest
// to the image's, cuts it into imageSize x imageSize tiles in row-major order
// and normalises each tile. When useThumbnail is set and more than one tile
// was produced, a whole-image tile resized to imageSize x imageSize is
// appended last.
//
// NOTE(review): the resized buffer always has 3 planes while the crop uses
// p.numChannels and normalisation assumes 3; confirm numChannels is always 3.
func (p ImageProcessor) processTiledImage(img image.Image) []processedVisionTile {
	srcBounds := img.Bounds()
	srcW, srcH := srcBounds.Dx(), srcBounds.Dy()

	candidates := nemotronTargetRatios(p.maxTiles)
	cols, rows := findClosestAspectRatio(float64(srcW)/float64(srcH), candidates, srcW, srcH, p.imageSize)

	side := p.imageSize
	canvasW, canvasH := side*cols, side*rows
	canvas := resizeImageBicubicCHW(img, canvasW, canvasH)
	tileSize := image.Point{X: side, Y: side}

	tiles := make([]processedVisionTile, 0, cols*rows+1)
	for row := range rows {
		for col := range cols {
			region := cropCHWRegion(canvas, canvasW, canvasH, p.numChannels, col*side, row*side, side, side)
			tiles = append(tiles, processedVisionTile{
				data: normalizeVisionCHW(region, p.imageMean, p.imageStd),
				size: tileSize,
			})
		}
	}

	if !p.useThumbnail || len(tiles) <= 1 {
		return tiles
	}
	overview := resizeImageBicubicCHW(img, side, side)
	return append(tiles, processedVisionTile{
		data: normalizeVisionCHW(overview, p.imageMean, p.imageStd),
		size: tileSize,
	})
}
// processDynamicImage resizes img to a whole number of patches chosen by
// dynamicPatchGrid and returns it as a single normalised tile. It fails when
// the computed grid has a non-positive dimension.
func (p ImageProcessor) processDynamicImage(img image.Image) ([]processedVisionTile, error) {
	srcBounds := img.Bounds()
	cols, rows := p.dynamicPatchGrid(srcBounds.Dx(), srcBounds.Dy())
	if cols <= 0 || rows <= 0 {
		return nil, errors.New("nemotron_h_omni: invalid dynamic image patch grid")
	}

	outW, outH := cols*p.patchSize, rows*p.patchSize
	pixels := resizeImageBicubicCHW(img, outW, outH)
	tile := processedVisionTile{
		data: normalizeVisionCHW(pixels, p.imageMean, p.imageStd),
		size: image.Point{X: outW, Y: outH},
	}
	return []processedVisionTile{tile}, nil
}
// dynamicPatchGrid chooses the patch grid for an image of the given pixel
// size and returns it as (width, height) in patches.
//
// The starting grid is round(dim/patchSize + 0.5) per axis, at least 1. The
// patch budget is maxNumPatches, or max(grid area, minNumPatches) when no
// maximum is configured. The grid is scaled down uniformly to fit the budget,
// scaled up again if that left fewer than minNumPatches patches while the
// budget exceeds the minimum, and finally each axis is rounded to a multiple
// of projectorScale. The height is rounded first and the width is rounded
// against the already-rounded height.
func (p ImageProcessor) dynamicPatchGrid(origWidth, origHeight int) (int, int) {
	patchesFor := func(pixels int) int {
		return max(1, int(math.Round(float64(pixels)/float64(p.patchSize)+0.5)))
	}
	rows := patchesFor(origHeight)
	cols := patchesFor(origWidth)
	total := rows * cols

	budget := p.maxNumPatches
	if budget <= 0 {
		budget = max(total, p.minNumPatches)
	}

	shrink := math.Min(math.Sqrt(float64(budget)/float64(total)), 1.0)
	targetRows := max(1, int(math.Floor(shrink*float64(rows))))
	targetCols := max(1, int(math.Floor(shrink*float64(cols))))

	if area := targetRows * targetCols; budget > p.minNumPatches && area < p.minNumPatches {
		grow := math.Sqrt(float64(p.minNumPatches) / float64(area))
		targetRows = int(math.Ceil(grow * float64(targetRows)))
		targetCols = int(math.Ceil(grow * float64(targetCols)))
	}

	targetRows = roundPatchGridForPixelShuffle(targetRows, targetCols, budget, p.projectorScale)
	targetCols = roundPatchGridForPixelShuffle(targetCols, targetRows, budget, p.projectorScale)
	return targetCols, targetRows
}
// roundPatchGridForPixelShuffle rounds v to a multiple of divisor.
//
// v is returned unchanged when divisor is at most 1 or v is already a
// multiple. Otherwise v is rounded up when the rounded value times other
// still fits in maxPatches, and rounded down (but never below divisor) when
// it does not.
func roundPatchGridForPixelShuffle(v, other, maxPatches, divisor int) int {
	if divisor <= 1 || v%divisor == 0 {
		return v
	}

	lower := v - v%divisor
	if upper := lower + divisor; upper*other <= maxPatches {
		return upper
	}
	return max(divisor, lower)
}
// nemotronImageRatio is a candidate tile grid, measured in tiles.
type nemotronImageRatio struct {
width int
height int
}
// nemotronTargetRatios lists every tile grid (width, height) whose tile count
// does not exceed maxTiles, without duplicates, sorted by ascending tile
// count.
func nemotronTargetRatios(maxTiles int) []nemotronImageRatio {
	ratios := make([]nemotronImageRatio, 0, maxTiles*maxTiles)
	for n := 1; n <= maxTiles; n++ {
		for w := 1; w <= n; w++ {
			for h := 1; h <= n; h++ {
				// A grid whose larger side is below n was already emitted
				// for a smaller n, so only keep grids that first fit at n.
				if max(w, h) != n || w*h > maxTiles {
					continue
				}
				ratios = append(ratios, nemotronImageRatio{width: w, height: h})
			}
		}
	}

	slices.SortFunc(ratios, func(a, b nemotronImageRatio) int {
		return a.width*a.height - b.width*b.height
	})
	return ratios
}
// findClosestAspectRatio picks the grid from targetRatios whose width/height
// ratio is closest to aspectRatio and returns it as (width, height).
//
// On an exact tie with the current best, the later candidate wins only when
// the image area (width*height) is larger than half the pixel area that
// candidate would cover at imageSize per tile. With no candidates the result
// is 1x1.
func findClosestAspectRatio(aspectRatio float64, targetRatios []nemotronImageRatio, width, height, imageSize int) (int, int) {
	best := nemotronImageRatio{width: 1, height: 1}
	smallestDiff := math.MaxFloat64
	pixels := width * height

	for _, candidate := range targetRatios {
		candidateAspect := float64(candidate.width) / float64(candidate.height)
		diff := math.Abs(aspectRatio - candidateAspect)

		switch {
		case diff < smallestDiff:
			smallestDiff, best = diff, candidate
		case diff == smallestDiff:
			halfCoverage := int(0.5 * float64(imageSize*imageSize*candidate.width*candidate.height))
			if pixels > halfCoverage {
				best = candidate
			}
		}
	}
	return best.width, best.height
}
// resizeImageBicubicCHW resamples img to outW x outH with a 4x4 bicubic
// kernel and returns the result as three consecutive planes (red, green,
// blue), each outW*outH values in row-major order.
//
// Source pixels are reduced to 8 bits per channel and scaled to [0, 1]; the
// alpha channel is ignored. Output values are not clamped.
func resizeImageBicubicCHW(img image.Image, outW, outH int) []float32 {
bounds := img.Bounds()
inW := bounds.Dx()
inH := bounds.Dy()
// Unpack the source into planar float form once, so the resampling loop
// below only indexes a flat slice.
src := make([]float32, 3*inW*inH)
for y := range inH {
for x := range inW {
r, g, b, _ := img.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA()
// RGBA returns 16-bit components; keep the high byte.
src[y*inW+x] = float32(r>>8) / 255.0
src[inW*inH+y*inW+x] = float32(g>>8) / 255.0
src[2*inW*inH+y*inW+x] = float32(b>>8) / 255.0
}
}
dst := make([]float32, 3*outW*outH)
scaleX := float64(inW) / float64(outW)
scaleY := float64(inH) / float64(outH)
for oy := range outH {
// Map the centre of the output pixel back into source coordinates.
srcY := scaleY*(float64(oy)+0.5) - 0.5
yBase := int(math.Floor(srcY))
yFrac := clampUnit(srcY - float64(yBase))
wy := torchBicubicWeights(yFrac)
for ox := range outW {
srcX := scaleX*(float64(ox)+0.5) - 0.5
xBase := int(math.Floor(srcX))
xFrac := clampUnit(srcX - float64(xBase))
wx := torchBicubicWeights(xFrac)
for c := range 3 {
var sum float64
channelOffset := c * inW * inH
// Accumulate the 4x4 neighbourhood; taps outside the image are
// clamped to the nearest edge pixel.
for ky := range 4 {
iy := clampIndex(yBase-1+ky, 0, inH-1)
for kx := range 4 {
ix := clampIndex(xBase-1+kx, 0, inW-1)
sum += float64(src[channelOffset+iy*inW+ix]) * wy[ky] * wx[kx]
}
}
dst[c*outW*outH+oy*outW+ox] = float32(sum)
}
}
}
return dst
}
func cropCHWRegion(values []float32, width, height, channels, left, top, cropW, cropH int) []float32 {
out := make([]float32, channels*cropW*cropH)
channelSize := width * height
cropSize := cropW * cropH
for c := range channels {
srcBase := c * channelSize
dstBase := c * cropSize
for y := range cropH {
copy(out[dstBase+y*cropW:dstBase+(y+1)*cropW], values[srcBase+(top+y)*width+left:srcBase+(top+y)*width+left+cropW])
}
}
return out
}
// normalizeVisionCHW returns (value - mean[c]) / std[c] for every value of a
// three-plane planar image, where c is the plane the value belongs to.
//
// Each plane is len(values)/3 long; any trailing values beyond three whole
// planes are left as zero in the result.
func normalizeVisionCHW(values []float32, mean, std [3]float32) []float32 {
	normalized := make([]float32, len(values))
	planeSize := len(values) / 3

	for c, m := range mean {
		s := std[c]
		start := c * planeSize
		for i, v := range values[start : start+planeSize] {
			normalized[start+i] = (v - m) / s
		}
	}
	return normalized
}
// torchBicubicWeights returns the four cubic-convolution weights for a sample
// located at fraction t between the second and third of four consecutive
// taps, using the coefficient a = -0.75.
func torchBicubicWeights(t float64) [4]float64 {
	const a = -0.75

	var weights [4]float64
	weights[0] = bicubicConvolution2(t+1.0, a)
	weights[1] = bicubicConvolution1(t, a)
	weights[2] = bicubicConvolution1(1.0-t, a)
	weights[3] = bicubicConvolution2(2.0-t, a)
	return weights
}
// bicubicConvolution1 evaluates the polynomial
// ((a+2)*x - (a+3))*x*x + 1. torchBicubicWeights uses it for the two inner
// taps.
func bicubicConvolution1(x, a float64) float64 {
	linear := (a+2)*x - (a + 3)
	return linear*x*x + 1
}
// bicubicConvolution2 evaluates the polynomial
// ((a*x - 5a)*x + 8a)*x - 4a. torchBicubicWeights uses it for the two outer
// taps.
func bicubicConvolution2(x, a float64) float64 {
	acc := a*x - 5*a
	acc = acc*x + 8*a
	return acc*x - 4*a
}
func clampUnit(v float64) float64 {
if v < 0 {
return 0
}
if v > 1 {
return 1
}
return v
}
// clampIndex limits v to [lo, hi]. The lower bound is checked first, so when
// lo is greater than hi a value below lo yields lo.
func clampIndex(v, lo, hi int) int {
	switch {
	case v < lo:
		return lo
	case v > hi:
		return hi
	default:
		return v
	}
}

View File

@@ -117,9 +117,7 @@ func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
func (m *Model) forwardHiddenStates(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) {
cache := m.Cache.(*HybridCache)
for i, layer := range m.Layers {
@@ -137,11 +135,24 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
}
// forwardLogits runs forwardHiddenStates on the given hidden states and
// applies the output projection to the result. An error from
// forwardHiddenStates is returned unchanged.
func (m *Model) forwardLogits(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) {
hiddenStates, err := m.forwardHiddenStates(ctx, batch, hiddenStates)
if err != nil {
return nil, err
}
return m.Output.Forward(ctx, hiddenStates), nil
}
func New(c fs.Config) (model.Model, error) {
// Forward embeds batch.Inputs with the token embedding and returns the result
// of forwardLogits for those embeddings.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
return m.forwardLogits(ctx, batch, hiddenStates)
}
func newTextModel(c fs.Config) (*Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
@@ -306,6 +317,10 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
// New builds the text model from c by delegating to newTextModel.
func New(c fs.Config) (model.Model, error) {
return newTextModel(c)
}
func init() {
model.Register("nemotron_h", New)
model.Register("nemotron_h_moe", New)

View File

@@ -0,0 +1,511 @@
package nemotronh
import (
"math"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// AudioOptions holds the hyperparameters shared by the audio encoder blocks.
type AudioOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
convKernelSize int
melBins int
sampleRate int
// subsamplingKernel and subsamplingStride are the kernel size and stride
// used by the strided convolutions in ForwardAudio.
subsamplingKernel int
subsamplingStride int
// scaleInput multiplies the subsampled features by sqrt(hiddenSize).
scaleInput bool
// eps is the epsilon passed to the normalisation layers.
eps float32
}
// AudioFeatureExtractor provides the analysis window and mel filter bank used
// for audio feature extraction. Both may be supplied as model tensors; the
// values are fetched once and cached under mu.
type AudioFeatureExtractor struct {
	FB     ml.Tensor `gguf:"fb"`
	Window ml.Tensor `gguf:"window"`

	mu      sync.Mutex
	fb      []float32
	window  []float32
	fbShape [2]int
}

// windowAndFilters returns the analysis window and the mel filter bank.
//
// With a nil receiver both are freshly built defaults. Otherwise each is
// taken from its tensor when the tensor exists and has the expected number of
// values (parakeetWinLength for the window, melBins*freqBins for the filter
// bank), falls back to the built default when it does not, and is cached for
// later calls.
//
// NOTE(review): the cached filter bank is reused even when a later call
// passes different melBins/freqBins/sampleRate; fbShape is recorded but not
// compared here. Confirm callers always pass the same values.
func (f *AudioFeatureExtractor) windowAndFilters(melBins, freqBins, sampleRate int) ([]float32, []float32) {
	if f == nil {
		return defaultParakeetWindow(), buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.window == nil {
		var window []float32
		if f.Window != nil {
			if stored := f.Window.BackendGet(); len(stored) == parakeetWinLength {
				window = stored
			}
		}
		if window == nil {
			window = defaultParakeetWindow()
		}
		f.window = window
	}

	if f.fb == nil {
		var bank []float32
		if f.FB != nil {
			if stored := f.FB.BackendGet(); len(stored) == melBins*freqBins {
				bank = stored
			}
		}
		if bank == nil {
			bank = buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
		}
		f.fb = bank
		f.fbShape = [2]int{melBins, freqBins}
	}

	return f.window, f.fb
}
// AudioSubsampling holds the weights of the convolutional front end applied
// by ForwardAudio: a first convolution, two depthwise/pointwise pairs and a
// final linear projection.
type AudioSubsampling struct {
Conv0 *nn.Conv2D `gguf:"conv0"`
DW1 *AudioDepthwiseConv2D `gguf:"dw1"`
PW1 *nn.Conv2D `gguf:"pw1"`
DW2 *AudioDepthwiseConv2D `gguf:"dw2"`
PW2 *nn.Conv2D `gguf:"pw2"`
Linear *nn.Linear `gguf:"linear"`
}
// AudioDepthwiseConv2D holds the weight and bias tensors consumed by
// forwardAudioDepthwiseConv2D.
type AudioDepthwiseConv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
// AudioFeedForward groups an up and a down projection.
// NOTE(review): AudioLayer in this file stores its feed-forward weights as
// separate FFN*Up/FFN*Down fields instead of using this type; confirm where
// this type is used.
type AudioFeedForward struct {
Up *nn.Linear `gguf:"up"`
Down *nn.Linear `gguf:"down"`
}
// AudioSelfAttention holds the weights of self-attention with relative
// position terms: Q/K/V/output projections, a projection applied to the
// position embeddings, and two optional biases added to the query.
type AudioSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
// RelativeKey projects the position embeddings.
RelativeKey *nn.Linear `gguf:"attn_rel_k"`
// BiasU is added to the query for the content term, BiasV for the
// relative-position term; either may be nil.
BiasU ml.Tensor `gguf:"attn_bias_u"`
BiasV ml.Tensor `gguf:"attn_bias_v"`
}
// AudioConvolutionModule holds the weights of the convolution sub-block of an
// AudioLayer: two pointwise projections around a depthwise kernel and a batch
// norm.
type AudioConvolutionModule struct {
Pointwise1 *nn.Linear `gguf:"conv_pw1"`
Depthwise ml.Tensor `gguf:"conv_dw.weight"`
BatchNorm *AudioBatchNorm1D `gguf:"conv_bn"`
Pointwise2 *nn.Linear `gguf:"conv_pw2"`
}
// AudioBatchNorm1D holds batch-norm parameters: the affine weight and bias
// and the stored running mean and variance.
type AudioBatchNorm1D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
RunningMean ml.Tensor `gguf:"running_mean"`
RunningVar ml.Tensor `gguf:"running_var"`
}
// AudioLayer is one audio encoder block. Its Forward method applies, in
// order: a half-weighted feed-forward, self-attention, the convolution
// module, a second half-weighted feed-forward and a final layer norm.
type AudioLayer struct {
FFN1Norm *nn.LayerNorm `gguf:"ffn1_norm"`
FFN1Up *nn.Linear `gguf:"ffn1_up"`
FFN1Down *nn.Linear `gguf:"ffn1_down"`
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
Attention *AudioSelfAttention
ConvNorm *nn.LayerNorm `gguf:"conv_norm"`
Conv *AudioConvolutionModule
FFN2Norm *nn.LayerNorm `gguf:"ffn2_norm"`
FFN2Up *nn.Linear `gguf:"ffn2_up"`
FFN2Down *nn.Linear `gguf:"ffn2_down"`
OutputNorm *nn.LayerNorm `gguf:"out_norm"`
}
// AudioModel is the audio encoder: a feature extractor, the convolutional
// subsampling front end and a stack of AudioLayer blocks, together with the
// shared AudioOptions.
type AudioModel struct {
FeatureExtractor *AudioFeatureExtractor `gguf:"feature_extractor"`
Subsampling *AudioSubsampling `gguf:"subsampling"`
Layers []AudioLayer `gguf:"blk"`
*AudioOptions
}
// AudioProjector maps audio encoder output through an RMS norm and two linear
// layers.
type AudioProjector struct {
	Norm    *nn.RMSNorm `gguf:"norm"`
	Linear1 *nn.Linear  `gguf:"1"`
	Linear2 *nn.Linear  `gguf:"2"`
}

// Forward normalises x with eps, applies the first linear layer, squares the
// ReLU of the result and applies the second linear layer. The output of each
// linear layer is passed through audioF32.
func (p *AudioProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	normed := p.Norm.Forward(ctx, x, eps)
	hidden := audioF32(ctx, p.Linear1.Forward(ctx, normed))
	rectified := hidden.RELU(ctx)
	squared := rectified.Mul(ctx, rectified)
	return audioF32(ctx, p.Linear2.Forward(ctx, squared))
}
// ForwardAudio encodes mel features into audio embeddings.
//
// The features pass through three strided convolution stages (a plain
// convolution, then two depthwise+pointwise pairs), each followed by ReLU and
// a time mask sized by the number of frames that remain valid after that
// stage. The result is flattened, projected, optionally scaled by
// sqrt(hiddenSize), trimmed to the valid frames, run through every encoder
// layer and, when projector is non-nil, through the projector.
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, validFrames int, projector *AudioProjector) ml.Tensor {
	kernel, stride := m.subsamplingKernel, m.subsamplingStride
	pad := audioConvPadding(kernel)
	frames := validFrames

	x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)

	// Stage 1: plain strided convolution.
	x = forwardAudioConv2D(ctx, m.Subsampling.Conv0, x, stride, stride, pad, pad, 1, 1)
	x = x.RELU(ctx)
	frames = convOutputLength(frames, kernel, stride, pad)
	x = applyAudioTimeMask(ctx, x, frames)

	// Stage 2: strided depthwise convolution followed by a 1x1 convolution.
	x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW1, x, stride, stride, pad, pad, 1, 1)
	x = forwardAudioConv2D(ctx, m.Subsampling.PW1, x, 1, 1, 0, 0, 1, 1)
	x = x.RELU(ctx)
	frames = convOutputLength(frames, kernel, stride, pad)
	x = applyAudioTimeMask(ctx, x, frames)

	// Stage 3: same structure as stage 2.
	x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW2, x, stride, stride, pad, pad, 1, 1)
	x = forwardAudioConv2D(ctx, m.Subsampling.PW2, x, 1, 1, 0, 0, 1, 1)
	x = x.RELU(ctx)
	frames = convOutputLength(frames, kernel, stride, pad)
	x = applyAudioTimeMask(ctx, x, frames)

	x = flattenAudioSubsamplingOutput(ctx, x)
	x = m.Subsampling.Linear.Forward(ctx, x)
	if m.scaleInput {
		x = x.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
	}

	// Drop padded frames so the encoder layers only see valid positions.
	if frames > 0 && frames < x.Dim(1) {
		x = x.Slice(ctx, 1, 0, frames, 1).Contiguous(ctx)
	}

	for i := range m.Layers {
		x = m.Layers[i].Forward(ctx, x, frames, m.AudioOptions)
	}

	if projector == nil {
		return x
	}
	return projector.Forward(ctx, x, m.eps)
}
// flattenAudioSubsamplingOutput swaps dimensions 1 and 2 of x and merges the
// first two resulting dimensions, returning a tensor of shape
// [Dim(2)*Dim(0), Dim(1)] in terms of the input.
func flattenAudioSubsamplingOutput(ctx ml.Context, x ml.Tensor) ml.Tensor {
	freq, frames, channels := x.Dim(0), x.Dim(1), x.Dim(2)

	// PyTorch flattens the subsampling output after [B, C, T, F] ->
	// [B, T, C, F], so F must remain the fastest dimension inside each
	// channel block before the linear projection.
	swapped := x.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	return swapped.Reshape(ctx, channels*freq, frames)
}
// Forward runs one audio encoder layer over x: a half-weighted feed-forward
// block, self-attention, the convolution module and a second half-weighted
// feed-forward block, each added back onto its input as a residual, followed
// by the output norm. validLen is forwarded to the attention block.
func (l *AudioLayer) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
	// First feed-forward branch contributes with weight 0.5.
	ffn1 := audioFeedForward(ctx, l.FFN1Up, l.FFN1Down, l.FFN1Norm.Forward(ctx, x, opts.eps))
	x = x.Add(ctx, ffn1.Scale(ctx, 0.5))

	attn := l.Attention.Forward(ctx, l.AttentionNorm.Forward(ctx, x, opts.eps), validLen, opts)
	x = x.Add(ctx, attn)

	conv := l.Conv.Forward(ctx, l.ConvNorm.Forward(ctx, x, opts.eps), opts)
	x = x.Add(ctx, conv)

	// Second feed-forward branch, again weighted by 0.5.
	ffn2 := audioFeedForward(ctx, l.FFN2Up, l.FFN2Down, l.FFN2Norm.Forward(ctx, x, opts.eps))
	x = x.Add(ctx, ffn2.Scale(ctx, 0.5))

	return l.OutputNorm.Forward(ctx, x, opts.eps)
}
// audioFeedForward applies up -> SiLU -> down, casting the result of each
// linear projection to F32 via audioF32.
func audioFeedForward(ctx ml.Context, up, down *nn.Linear, x ml.Tensor) ml.Tensor {
	activated := audioF32(ctx, up.Forward(ctx, x)).SILU(ctx)
	return audioF32(ctx, down.Forward(ctx, activated))
}
// Forward computes multi-head self-attention over x with an additional
// relative-position term and returns the output projection cast to F32.
//
// The sequence length is taken from x.Dim(1). The query, key and value
// projections are reshaped to (headDim, numHeads, seqLen). validLen is the
// number of leading positions treated as valid: when it is positive and
// smaller than seqLen, an additive mask from audioAttentionMask is applied
// before the softmax.
func (a *AudioSelfAttention) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
seqLen := x.Dim(1)
headDim := opts.headDim
numHeads := opts.numHeads
q := audioF32(ctx, a.Query.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
k := audioF32(ctx, a.Key.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
v := audioF32(ctx, a.Value.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
// Two biased copies of the query: qU feeds the key (content) logits and qV
// feeds the relative-position logits. Either bias is optional.
qU := q
if a.BiasU != nil {
qU = qU.Add(ctx, audioF32(ctx, a.BiasU).Reshape(ctx, headDim, numHeads, 1))
}
qV := q
if a.BiasV != nil {
qV = qV.Add(ctx, audioF32(ctx, a.BiasV).Reshape(ctx, headDim, numHeads, 1))
}
// Content logits: keys against the BiasU-shifted queries.
qP := qU.Permute(ctx, 0, 2, 1, 3)
kP := k.Permute(ctx, 0, 2, 1, 3)
logits := kP.MulmatFullPrec(ctx, qP)
// Position logits: the sinusoidal table covers 2*seqLen-1 relative offsets
// and is projected through RelativeKey before being matched against the
// BiasV-shifted queries.
positionEmbeddings := parakeetPositionEmbeddings(ctx, seqLen, opts.hiddenSize)
relKey := audioF32(ctx, a.RelativeKey.Forward(ctx, positionEmbeddings)).Reshape(ctx, headDim, numHeads, 2*seqLen-1)
pP := relKey.Permute(ctx, 0, 2, 1, 3)
qVP := qV.Permute(ctx, 0, 2, 1, 3)
relLogits := pP.MulmatFullPrec(ctx, qVP)
// Reduce the 2*seqLen-1 position axis to seqLen entries per query.
relLogits = relativeShiftParakeet(ctx, relLogits, seqLen, numHeads)
logits = logits.Add(ctx, relLogits)
// Scale the summed logits by headDim^-0.5.
logits = logits.Scale(ctx, math.Pow(float64(headDim), -0.5))
if validLen > 0 && validLen < seqLen {
logits = logits.Add(ctx, audioAttentionMask(ctx, seqLen, validLen))
}
logits = logits.Softmax(ctx)
// Weight the values by the attention probabilities, then restore the
// (hiddenSize, seqLen) layout for the output projection.
vP := v.Permute(ctx, 0, 2, 1, 3)
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
out := vPT.Mulmat(ctx, logits)
out = out.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
out = out.Reshape(ctx, opts.hiddenSize, seqLen)
return audioF32(ctx, a.Output.Forward(ctx, out))
}
// Forward runs the convolution module: a pointwise projection whose output is
// split in half along dim 0 into a value and a sigmoid gate, a depthwise 1D
// convolution over the gated value, batch norm, SiLU and a final pointwise
// projection.
func (c *AudioConvolutionModule) Forward(ctx ml.Context, x ml.Tensor, opts *AudioOptions) ml.Tensor {
	projected := audioF32(ctx, c.Pointwise1.Forward(ctx, x))

	// Gated linear unit: first half is the value, second half the gate.
	half := projected.Dim(0) / 2
	value := projected.Slice(ctx, 0, 0, half, 1).Contiguous(ctx)
	gate := projected.Slice(ctx, 0, half, 2*half, 1).Contiguous(ctx).Sigmoid(ctx)
	gated := value.Mul(ctx, gate)

	convolved := audioDepthwiseConv1DSame(ctx, gated, c.Depthwise, audioConvPadding(opts.convKernelSize))
	normalized := c.BatchNorm.Forward(ctx, convolved, opts.eps)
	return audioF32(ctx, c.Pointwise2.Forward(ctx, normalized.SILU(ctx)))
}
// audioF32 returns x unchanged when it is already F32 and otherwise a copy
// cast to F32.
func audioF32(ctx ml.Context, x ml.Tensor) ml.Tensor {
	if x.DType() != ml.DTypeF32 {
		// The Metal binary kernels used by the audio graph need F32 operands
		// here. The cast probably costs audio throughput and should be
		// revisited once the precision/speed tradeoff has been validated
		// against BF16-native elementwise paths.
		return x.Cast(ctx, ml.DTypeF32)
	}
	return x
}
// Forward normalizes x with the stored running statistics:
// (x - RunningMean) / sqrt(RunningVar + eps), followed by the optional Weight
// and Bias. x is returned untouched when the receiver or either running
// statistic is nil.
func (b *AudioBatchNorm1D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	if b == nil || b.RunningMean == nil || b.RunningVar == nil {
		return x
	}

	// eps is materialized as a vector over dim 0 so it can be added to the
	// running variance as a tensor operand.
	width := x.Dim(0)
	epsilon := make([]float32, width)
	for i := 0; i < width; i++ {
		epsilon[i] = eps
	}

	shiftedVar := b.RunningVar.Add(ctx, ctx.Input().FromFloats(epsilon, width))
	centered := x.Sub(ctx, b.RunningMean)
	out := centered.Div(ctx, shiftedVar.Sqrt(ctx))
	if b.Weight != nil {
		out = out.Mul(ctx, b.Weight)
	}
	if b.Bias != nil {
		out = out.Add(ctx, b.Bias)
	}
	return out
}
// forwardAudioConv2D convolves x with conv.Weight using the given strides
// (s0, s1), paddings (p0, p1) and dilations (d0, d1), then adds conv.Bias,
// reshaped to (1, 1, -1), when a bias is present.
func forwardAudioConv2D(ctx ml.Context, conv *nn.Conv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	out := conv.Weight.Contiguous(ctx).Conv2D(ctx, x, s0, s1, p0, p1, d0, d1)
	if conv.Bias == nil {
		return out
	}
	return out.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1))
}
// forwardAudioDepthwiseConv2D applies audioDepthwiseConv2D with conv.Weight
// and then adds conv.Bias, reshaped to (1, 1, -1), when a bias is present.
func forwardAudioDepthwiseConv2D(ctx ml.Context, conv *AudioDepthwiseConv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	out := audioDepthwiseConv2D(ctx, x, conv.Weight, s0, s1, p0, p1, d0, d1)
	if conv.Bias == nil {
		return out
	}
	return out.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1))
}
// applyAudioTimeMask zeroes every position of x along dim 1 at index validLen
// or beyond by multiplying with a 0/1 mask. x is returned unchanged when
// validLen is non-positive or already covers all of dim 1.
func applyAudioTimeMask(ctx ml.Context, x ml.Tensor, validLen int) ml.Tensor {
	frames := x.Dim(1)
	if validLen <= 0 || validLen >= frames {
		return x
	}

	keep := make([]float32, frames)
	for i := 0; i < validLen; i++ {
		keep[i] = 1
	}
	return x.Mul(ctx, ctx.Input().FromFloats(keep, 1, frames, 1, 1))
}
// audioDepthwiseConv1DSame computes a depthwise 1D convolution along dim 1 of
// x that preserves the sequence length, implemented as a sum of shifted copies
// of x, each multiplied by one kernel tap.
//
// kernel.Dim(0) is the kernel size; padding is the number of zero positions
// assumed on each side of the sequence (tap k reads position i + k - padding).
// Taps whose shift is at least as long as the sequence only ever see zero
// padding, so they are skipped: slicing for them would request an empty or
// inverted range, which happens for very short inputs.
func audioDepthwiseConv1DSame(ctx ml.Context, x, kernel ml.Tensor, padding int) ml.Tensor {
	kernelSize := kernel.Dim(0)
	seqLen := x.Dim(1)
	// Transpose so a single tap can be sliced out along dim 1 and multiplied
	// against x.
	kernelT := kernel.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	var out ml.Tensor
	for k := range kernelSize {
		offset := k - padding
		if offset != 0 && (offset >= seqLen || -offset >= seqLen) {
			// The whole shifted window lies in the zero padding; the term
			// would be all zeros.
			continue
		}

		shifted := x
		switch {
		case offset > 0:
			// Drop the first offset positions and zero-fill at the end.
			shifted = x.Slice(ctx, 1, offset, seqLen, 1).Contiguous(ctx)
			shifted = shifted.PadExt(ctx, 0, 0, 0, offset, 0, 0, 0, 0)
		case offset < 0:
			// Drop the last -offset positions and zero-fill at the start.
			shift := -offset
			shifted = x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
			shifted = shifted.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
		}

		wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx)
		term := shifted.Mul(ctx, wk)
		if out == nil {
			out = term
		} else {
			out = out.Add(ctx, term)
		}
	}
	return out
}
// audioDepthwiseConv2D computes a depthwise 2D convolution of x with kernel as
// a sum over kernel taps: for every (kx, ky) it takes a strided slice of the
// zero-padded input and multiplies it by that tap's weights.
//
// s0/s1 are the strides and p0/p1 the symmetric paddings along dims 0 and 1.
// Only dilation 1 is supported; any other d0/d1 panics.
func audioDepthwiseConv2D(ctx ml.Context, x, kernel ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
if d0 != 1 || d1 != 1 {
panic("audio depthwise conv2d only supports dilation 1")
}
kernel = kernel.Contiguous(ctx)
kernelW, kernelH := kernel.Dim(0), kernel.Dim(1)
outW := convOutputLength(x.Dim(0), kernelW, s0, p0)
outH := convOutputLength(x.Dim(1), kernelH, s1, p1)
padded := x.PadExt(ctx, p0, p0, p1, p1, 0, 0, 0, 0)
var out ml.Tensor
for ky := range kernelH {
for kx := range kernelW {
// Input samples seen by tap (kx, ky): outW x outH positions, starting at
// the tap offset and stepping by the stride.
patch := padded.Slice(ctx, 0, kx, kx+s0*(outW-1)+1, s0).Contiguous(ctx)
patch = patch.Slice(ctx, 1, ky, ky+s1*(outH-1)+1, s1).Contiguous(ctx)
wk := kernel.Slice(ctx, 0, kx, kx+1, 1).Slice(ctx, 1, ky, ky+1, 1).Contiguous(ctx)
// Lay the tap weights out as (1, 1, n, ...) so they line up with dim 2 of
// the patch; a kernel whose dim 2 is 1 has its dims 2 and 3 swapped.
if wk.Dim(2) == 1 {
wk = wk.Permute(ctx, 0, 1, 3, 2).Contiguous(ctx)
} else {
wk = wk.Reshape(ctx, 1, 1, wk.Dim(2), wk.Dim(3))
}
term := patch.Mul(ctx, wk)
if out == nil {
out = term
} else {
out = out.Add(ctx, term)
}
}
}
return out
}
// convOutputLength returns how many output positions a convolution with the
// given kernel size, stride and symmetric padding produces for an input of
// inputLength positions: floor((inputLength + 2*padding - kernel) / stride) + 1.
//
// It returns 0 for a non-positive inputLength and when the kernel does not fit
// into the padded input at all.
func convOutputLength(inputLength, kernel, stride, padding int) int {
	if inputLength <= 0 {
		return 0
	}
	span := inputLength + 2*padding - kernel
	if span < 0 {
		// Go's integer division truncates toward zero, so a negative span
		// would otherwise come out as a spurious 1 or a negative length.
		return 0
	}
	return span/stride + 1
}
// audioConvPadding returns the per-side padding used for a kernel of the
// given size: (kernel - 1) / 2 with integer division.
func audioConvPadding(kernel int) int {
	margin := kernel - 1
	return margin / 2
}
// parakeetPositionEmbeddings builds the sinusoidal table used for relative
// position attention. It holds one row of hiddenSize values for each of the
// 2*seqLen-1 relative positions, ordered from seqLen-1 down to -(seqLen-1).
// Within a row, element 2*i is sin(pos * f_i) and element 2*i+1 is
// cos(pos * f_i), with f_i = 10000^(-2*i/hiddenSize).
func parakeetPositionEmbeddings(ctx ml.Context, seqLen, hiddenSize int) ml.Tensor {
	half := hiddenSize / 2
	positions := 2*seqLen - 1

	// The inverse frequencies depend only on the feature index, so they are
	// computed once rather than once per position.
	invFreq := make([]float64, half)
	for i := range invFreq {
		invFreq[i] = math.Pow(10000, -float64(2*i)/float64(hiddenSize))
	}

	values := make([]float32, hiddenSize*positions)
	for posIdx := range positions {
		pos := float64(seqLen - 1 - posIdx)
		row := values[posIdx*hiddenSize : (posIdx+1)*hiddenSize]
		for i, freq := range invFreq {
			angle := pos * freq
			row[2*i] = float32(math.Sin(angle))
			row[2*i+1] = float32(math.Cos(angle))
		}
	}
	return ctx.Input().FromFloats(values, hiddenSize, positions)
}
// relativeShiftParakeet converts position logits with 2*seqLen-1 entries along
// dim 0 into seqLen entries per query, using a pad/reshape/slice sequence.
// The returned tensor keeps indices [0, seqLen) of dim 0 after the shift.
func relativeShiftParakeet(ctx ml.Context, x ml.Tensor, seqLen, numHeads int) ml.Tensor {
positionLen := 2*seqLen - 1
// Add one element in front of dim 0 (presumably zero-filled — confirm
// against PadExt), reinterpret the buffer with dims 0 and 1 swapped in
// size, drop the first entry of dim 1, and reinterpret back. This shifts
// each query row by a different amount.
x = x.PadExt(ctx, 1, 0, 0, 0, 0, 0, 0, 0)
x = x.Reshape(ctx, seqLen, positionLen+1, numHeads)
x = x.Slice(ctx, 1, 1, positionLen+1, 1).Contiguous(ctx)
x = x.Reshape(ctx, positionLen, seqLen, numHeads)
return x.Slice(ctx, 0, 0, seqLen, 1).Contiguous(ctx)
}
// audioAttentionMask builds an additive (seqLen, seqLen, 1) mask that is 0
// where both the query and the key index are below validLen and -1e9
// everywhere else.
func audioAttentionMask(ctx ml.Context, seqLen, validLen int) ml.Tensor {
	const blocked = -1e9

	mask := make([]float32, seqLen*seqLen)
	for query := 0; query < seqLen; query++ {
		row := mask[query*seqLen : (query+1)*seqLen]
		for key := range row {
			if query < validLen && key < validLen {
				continue
			}
			row[key] = blocked
		}
	}
	return ctx.Input().FromFloats(mask, seqLen, seqLen, 1)
}
// newAudioModel allocates the audio encoder described by c, or returns nil
// when "audio.block_count" is missing or zero.
func newAudioModel(c fs.Config) *AudioModel {
	if blocks := int(c.Uint("audio.block_count", 0)); blocks != 0 {
		return &AudioModel{
			Layers:       make([]AudioLayer, blocks),
			AudioOptions: newAudioOptions(c),
		}
	}
	return nil
}
// newAudioProjector returns an empty AudioProjector when the config declares
// at least one audio block ("audio.block_count") and nil otherwise.
func newAudioProjector(c fs.Config) *AudioProjector {
	hasAudio := c.Uint("audio.block_count", 0) != 0
	if !hasAudio {
		return nil
	}
	return &AudioProjector{}
}
// newAudioOptions reads the audio encoder hyperparameters from c, falling
// back to a built-in default for every key that is absent. headDim is derived
// as hiddenSize divided by the head count (treated as at least 1), and the
// default feed-forward length is four times hiddenSize.
func newAudioOptions(c fs.Config) *AudioOptions {
	opts := &AudioOptions{
		hiddenSize: int(c.Uint("audio.embedding_length", 1024)),
		numHeads:   int(c.Uint("audio.attention.head_count", 8)),
	}
	opts.headDim = opts.hiddenSize / max(1, opts.numHeads)
	opts.intermediateSize = int(c.Uint("audio.feed_forward_length", uint32(opts.hiddenSize*4)))
	opts.convKernelSize = int(c.Uint("audio.conv_kernel_size", 9))
	opts.melBins = int(c.Uint("audio.num_mel_bins", 128))
	opts.sampleRate = int(c.Uint("audio.sample_rate", 16000))
	opts.subsamplingKernel = int(c.Uint("audio.subsampling_conv_kernel_size", 3))
	opts.subsamplingStride = int(c.Uint("audio.subsampling_conv_stride", 2))
	opts.scaleInput = c.Bool("audio.scale_input", false)
	opts.eps = c.Float("audio.attention.layer_norm_epsilon", 1e-5)
	return opts
}
// defaultAudioOptions returns the hyperparameters that newAudioOptions would
// produce from an empty config.
func defaultAudioOptions() *AudioOptions {
	const (
		hidden = 1024
		heads  = 8
	)
	return &AudioOptions{
		hiddenSize:        hidden,
		numHeads:          heads,
		headDim:           hidden / heads,
		intermediateSize:  4 * hidden,
		convKernelSize:    9,
		melBins:           128,
		sampleRate:        16000,
		subsamplingKernel: 3,
		subsamplingStride: 2,
		eps:               1e-5,
	}
}

View File

@@ -0,0 +1,239 @@
package nemotronh
import (
"bytes"
"errors"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
// OmniModel is the multimodal Nemotron-H variant: the text model plus the
// vision encoder, audio encoder and their projectors, all embedded so their
// fields and methods are promoted. The gguf tags select the tensor-name prefix
// each component is loaded from.
type OmniModel struct {
*Model
*VisionModel `gguf:"v"`
*AudioModel `gguf:"a"`
*MultiModalProjector `gguf:"mm"`
*AudioProjector `gguf:"mm.a"`
ImageProcessor
// Placeholder and delimiter token IDs used by PostTokenize when expanding
// multimodal inputs; NewOmni reads them from the model config.
imageTokenID int32
imageStartToken int32
imageEndToken int32
audioTokenID int32
}
// Compile-time check that *OmniModel implements model.MultimodalProcessor.
var _ model.MultimodalProcessor = (*OmniModel)(nil)
// NewOmni builds an OmniModel from c: the text model, the vision and audio
// encoders with their projectors, the image processor and the special token
// IDs. Missing token IDs default to 18 (image), 19 (image start), 20 (image
// end) and 27 (audio). It fails only when the text model cannot be built.
func NewOmni(c fs.Config) (model.Model, error) {
	base, err := newTextModel(c)
	if err != nil {
		return nil, err
	}

	omni := &OmniModel{
		Model:               base,
		VisionModel:         newVisionModel(c),
		AudioModel:          newAudioModel(c),
		MultiModalProjector: newMultiModalProjector(c),
		AudioProjector:      newAudioProjector(c),
		ImageProcessor:      newImageProcessor(c),
		imageTokenID:        int32(c.Uint("vision.image_token_id", 18)),
		imageStartToken:     int32(c.Uint("vision.image_start_token_id", 19)),
		imageEndToken:       int32(c.Uint("vision.image_end_token_id", 20)),
		audioTokenID:        int32(c.Uint("audio.sound_token_id", 27)),
	}
	return omni, nil
}
// EncodeMultimodal turns raw attachment bytes into embeddings. Data that
// isAudioData recognizes is handed to encodeAudioMultimodal; everything else
// is decoded as an image, split into tiles by the image processor, and each
// tile is run through the vision encoder and projector. One input.Multimodal
// is returned per tile.
//
// It returns model.ErrNoVisionModel when the vision encoder or projector is
// not loaded, and an error when the image cannot be decoded or processed.
func (m *OmniModel) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if isAudioData(multimodalData) {
return m.encodeAudioMultimodal(ctx, multimodalData)
}
if m.VisionModel == nil || m.MultiModalProjector == nil || len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
tiles, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
mm := make([]input.Multimodal, 0, len(tiles))
for _, tile := range tiles {
// Patch grid of this tile; a tile smaller than one patch is rejected.
patches := visionPatchGrid{
Width: tile.size.X / m.ImageProcessor.patchSize,
Height: tile.size.Y / m.ImageProcessor.patchSize,
}
if patches.Width == 0 || patches.Height == 0 {
return nil, errors.New("nemotron_h_omni: invalid resized image dimensions")
}
patchInput := packVisionPatchesCHW(tile.data, tile.size.X, tile.size.Y, m.ImageProcessor.numChannels, m.ImageProcessor.patchSize)
visionOutputs := m.VisionModel.ForwardPacked(ctx, patchInput, patches)
projected := m.MultiModalProjector.Forward(ctx, visionOutputs, patches)
mm = append(mm, input.Multimodal{Tensor: projected})
}
return mm, nil
}
// audioTag is stored in input.Multimodal.Data to mark an embedding as audio;
// PostTokenize type-asserts on it to tell audio chunks from image chunks.
type audioTag struct{}
// encodeAudioMultimodal decodes WAV data at the audio model's sample rate,
// computes its mel spectrogram, and runs the audio encoder and projector. It
// returns a single input.Multimodal tagged with audioTag.
//
// NOTE(review): a missing audio encoder or projector is reported as
// model.ErrNoVisionModel — confirm whether a dedicated audio error is wanted.
func (m *OmniModel) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
if m.AudioModel == nil || m.AudioProjector == nil || len(m.AudioModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
samples, err := decodeWAV(data, m.AudioModel.sampleRate)
if err != nil {
return nil, err
}
melData, frames, validFrames, err := computeParakeetMelSpectrogram(samples, m.AudioModel.FeatureExtractor, m.AudioModel.AudioOptions)
if err != nil {
return nil, err
}
// The spectrogram is uploaded as a (melBins, frames) tensor; validFrames is
// passed on so the encoder can treat the remaining frames as padding.
melTensor := ctx.Input().FromFloats(melData, m.AudioModel.melBins, frames)
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, validFrames, m.AudioProjector)
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
}
// PostLoad is a no-op for OmniModel and always returns nil.
func (m *OmniModel) PostLoad() error {
return nil
}
// PostTokenize expands every input that carries multimodal embeddings into a
// run of placeholder tokens, one per embedding column (Tensor.Dim(1)). Inputs
// without multimodal data pass through unchanged.
//
// Audio inputs (first chunk tagged with audioTag) become audio tokens; the
// first token of each chunk carries the embedding and a SameBatch covering the
// rest of the chunk. Image inputs are wrapped in the image start/end tokens
// when those are configured; the start token's SameBatch covers all image
// tokens plus the end token.
//
// An error is returned when an input has no embedding columns at all or when
// any of its chunks has a nil or empty tensor.
func (m *OmniModel) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
// Fall back to the default image token when the ID was left at zero.
imageToken := m.imageTokenID
if imageToken == 0 {
imageToken = 18
}
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
// Total number of embedding columns across all chunks of this input.
totalTokens := 0
for _, mm := range inp.Multimodal {
if mm.Tensor == nil {
continue
}
totalTokens += mm.Tensor.Dim(1)
}
if totalTokens <= 0 {
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
}
// Only the first chunk is inspected to decide between audio and image.
if _, ok := inp.Multimodal[0].Data.(audioTag); ok {
audioToken := m.audioTokenID
if audioToken == 0 {
audioToken = 27
}
for i, mm := range inp.Multimodal {
tokenCount := 0
if mm.Tensor != nil {
tokenCount = mm.Tensor.Dim(1)
}
if tokenCount <= 0 {
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
}
first := &input.Input{Token: audioToken, SameBatch: tokenCount - 1}
// The hash is attached to the first chunk only.
if i == 0 {
first.MultimodalHash = inp.MultimodalHash
}
first.Multimodal = []input.Multimodal{mm}
result = append(result, first)
if tokenCount > 1 {
result = append(result, slices.Repeat([]*input.Input{{Token: audioToken}}, tokenCount-1)...)
}
}
continue
}
if m.imageStartToken > 0 {
result = append(result, &input.Input{
Token: m.imageStartToken,
SameBatch: totalTokens + btoi(m.imageEndToken > 0),
})
}
for _, mm := range inp.Multimodal {
tokenCount := 0
if mm.Tensor != nil {
tokenCount = mm.Tensor.Dim(1)
}
if tokenCount <= 0 {
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
}
// Unlike the audio path, every image chunk's first token carries the hash.
result = append(result, &input.Input{
Token: imageToken,
Multimodal: []input.Multimodal{mm},
MultimodalHash: inp.MultimodalHash,
})
if tokenCount > 1 {
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, tokenCount-1)...)
}
}
if m.imageEndToken > 0 {
result = append(result, &input.Input{Token: m.imageEndToken})
}
}
return result, nil
}
// btoi converts a bool to an int: 1 for true, 0 for false.
func btoi(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
// Forward embeds the batch tokens, overwrites the placeholder positions with
// the multimodal embeddings attached to the batch, and runs the text model to
// produce logits.
func (m *OmniModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
// Work on a duplicate when embeddings will be written into the hidden states.
if len(batch.Multimodal) > 0 {
hiddenStates = hiddenStates.Duplicate(ctx)
}
for _, mm := range batch.Multimodal {
// offset is the column where the next chunk of this input goes.
offset := mm.Index
for _, multimodal := range mm.Multimodal {
if multimodal.Tensor == nil {
continue
}
tensor := multimodal.Tensor
// Presumably copies tensor into a flat view of hiddenStates that starts at
// column offset and spans Dim(0)*Dim(1) elements — confirm Copy direction.
ctx.Forward(tensor.Copy(ctx, hiddenStates.View(ctx, offset*hiddenStates.Stride(1), tensor.Dim(0)*tensor.Dim(1))))
offset += tensor.Dim(1)
}
}
return m.forwardLogits(ctx, batch, hiddenStates)
}
// init registers NewOmni as the constructor for the "nemotron_h_omni"
// architecture.
func init() {
model.Register("nemotron_h_omni", NewOmni)
}

View File

@@ -0,0 +1,606 @@
package nemotronh
import (
"bytes"
"encoding/base64"
"encoding/binary"
"image"
"image/color"
"math"
"os"
"path/filepath"
"slices"
"strings"
"testing"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
backendggml "github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
// fakeTensor is a test double that reports fixed dimensions. The embedded
// *backendggml.Tensor is left nil in the tests, so only Dim is safe to call.
type fakeTensor struct {
*backendggml.Tensor
dims []int
}
// Dim returns the configured size of dimension i.
func (t *fakeTensor) Dim(i int) int {
return t.dims[i]
}
// setupTestContext creates a backend from a minimal temporary GGUF file (only
// "general.architecture" set, no tensors) and returns an input context on it.
// The context and backend are closed through t.Cleanup.
func setupTestContext(t *testing.T) ml.Context {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "*.gguf")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := fsggml.WriteGGUF(f, fsggml.KV{"general.architecture": "test"}, nil); err != nil {
t.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
ctx := b.NewContext().Input()
t.Cleanup(func() {
ctx.Close()
b.Close()
})
return ctx
}
// TestPostTokenizeImageSpans checks the token layout PostTokenize produces for
// an input with two 256-column image chunks: start token, 2x256 image tokens,
// end token, surrounded by the untouched text tokens.
func TestPostTokenizeImageSpans(t *testing.T) {
m := &OmniModel{
imageTokenID: 18,
imageStartToken: 19,
imageEndToken: 20,
}
makeChunk := func() input.Multimodal {
return input.Multimodal{Tensor: &fakeTensor{dims: []int{2688, 256, 1, 1}}}
}
in := []*input.Input{
{Token: 7},
{
Multimodal: []input.Multimodal{makeChunk(), makeChunk()},
MultimodalHash: 99,
},
{Token: 8},
}
out, err := m.PostTokenize(in)
if err != nil {
t.Fatalf("PostTokenize() error = %v", err)
}
// 1 text + 1 start + 512 image + 1 end + 1 text.
if len(out) != 516 {
t.Fatalf("len(out) = %d, want 516", len(out))
}
if out[0].Token != 7 {
t.Fatalf("out[0].Token = %d, want 7", out[0].Token)
}
if out[1].Token != 19 {
t.Fatalf("out[1].Token = %d, want 19", out[1].Token)
}
// The start token must keep all 512 image tokens and the end token together.
if out[1].SameBatch != 513 {
t.Fatalf("out[1].SameBatch = %d, want 513", out[1].SameBatch)
}
if out[2].Token != 18 || len(out[2].Multimodal) != 1 || out[2].MultimodalHash != 99 || out[2].SameBatch != 0 {
t.Fatalf("unexpected first image token: %+v", *out[2])
}
if out[258].Token != 18 || len(out[258].Multimodal) != 1 || out[258].MultimodalHash != 99 || out[258].SameBatch != 0 {
t.Fatalf("unexpected second image token: %+v", *out[258])
}
if out[514].Token != 20 {
t.Fatalf("out[514].Token = %d, want 20", out[514].Token)
}
if out[515].Token != 8 {
t.Fatalf("out[515].Token = %d, want 8", out[515].Token)
}
}
// TestProjectorPixelShuffleMatchesReferenceV2Order feeds a 4x2 patch grid with
// two channels through pixelShuffleVisionOutputs (factor 2) and compares the
// element order against hard-coded reference values. Each input value encodes
// its coordinates as 100*y + 10*x + c.
func TestProjectorPixelShuffleMatchesReferenceV2Order(t *testing.T) {
ctx := setupTestContext(t)
hidden := 2
width := 4
height := 2
values := make([]float32, 0, hidden*width*height)
for y := range height {
for x := range width {
for c := range hidden {
values = append(values, float32(100*y+10*x+c))
}
}
}
got := pixelShuffleVisionOutputs(ctx, ctx.FromFloats(values, hidden, width*height), visionPatchGrid{
Width: width,
Height: height,
}, 2)
ctx.Forward(got).Compute(got)
want := []float32{
0, 1, 10, 11, 100, 101, 110, 111,
20, 21, 30, 31, 120, 121, 130, 131,
}
if got.Shape()[0] != 8 || got.Shape()[1] != 2 {
t.Fatalf("shape = %v, want [8 2 1]", got.Shape())
}
gotValues := got.BackendGet()
if len(gotValues) != len(want) {
t.Fatalf("len(got) = %d, want %d", len(gotValues), len(want))
}
for i := range want {
if gotValues[i] != want[i] {
t.Fatalf("got[%d] = %v, want %v", i, gotValues[i], want[i])
}
}
}
// TestPostTokenizeAudioSpans checks that a 13-column audio chunk expands to 13
// audio tokens, with only the first one carrying the embedding, the hash and
// a SameBatch of 12.
func TestPostTokenizeAudioSpans(t *testing.T) {
m := &OmniModel{
audioTokenID: 27,
}
in := []*input.Input{
{Token: 7},
{
Multimodal: []input.Multimodal{{
Tensor: &fakeTensor{dims: []int{2688, 13, 1, 1}},
Data: audioTag{},
}},
MultimodalHash: 99,
},
{Token: 8},
}
out, err := m.PostTokenize(in)
if err != nil {
t.Fatalf("PostTokenize() error = %v", err)
}
// 1 text + 13 audio + 1 text.
if len(out) != 15 {
t.Fatalf("len(out) = %d, want 15", len(out))
}
if out[0].Token != 7 || out[14].Token != 8 {
t.Fatalf("unexpected surrounding tokens: first=%d last=%d", out[0].Token, out[14].Token)
}
for i := 1; i <= 13; i++ {
if out[i].Token != 27 {
t.Fatalf("out[%d].Token = %d, want 27", i, out[i].Token)
}
}
if len(out[1].Multimodal) != 1 || out[1].MultimodalHash != 99 {
t.Fatalf("first audio token did not carry multimodal payload: %+v", *out[1])
}
if out[1].SameBatch != 12 {
t.Fatalf("first audio token SameBatch = %d, want 12", out[1].SameBatch)
}
if len(out[2].Multimodal) != 0 {
t.Fatalf("only the first audio token should carry multimodal payload: %+v", *out[2])
}
}
// TestParakeetAudioPreprocessShapes decodes one second of a 440 Hz sine at
// 16 kHz and checks the sample count, the frame counts and the mel buffer size
// produced with the default audio options, and that the frame past validFrames
// is all zeros.
func TestParakeetAudioPreprocessShapes(t *testing.T) {
data := sineWAV(t, 16000, 440, 1.0)
samples, err := decodeWAV(data, 16000)
if err != nil {
t.Fatal(err)
}
if got, want := len(samples), 16000; got != want {
t.Fatalf("sample count = %d, want %d", got, want)
}
mel, frames, validFrames, err := computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions())
if err != nil {
t.Fatal(err)
}
if frames != 101 {
t.Fatalf("frames = %d, want 101", frames)
}
if validFrames != 100 {
t.Fatalf("validFrames = %d, want 100", validFrames)
}
if len(mel) != 101*128 {
t.Fatalf("len(mel) = %d, want %d", len(mel), 101*128)
}
// The mel buffer is indexed frame-major here: 128 bins per frame.
lastFrame := mel[100*128 : 101*128]
if !slices.Equal(lastFrame, make([]float32, 128)) {
t.Fatal("expected masked final frame to be zero")
}
}
// TestParakeetAudioPreprocessMatchesIntegrationWAVReference runs the mel
// preprocessing on the WAV clip embedded in the integration test data and
// spot-checks selected (frame, bin) values against reference numbers with a
// tolerance of 1e-4.
func TestParakeetAudioPreprocessMatchesIntegrationWAVReference(t *testing.T) {
data := integrationAudioWAV(t)
samples, err := decodeWAV(data, 16000)
if err != nil {
t.Fatal(err)
}
if got, want := len(samples), 42083; got != want {
t.Fatalf("sample count = %d, want %d", got, want)
}
mel, frames, validFrames, err := computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions())
if err != nil {
t.Fatal(err)
}
if frames != 264 {
t.Fatalf("frames = %d, want 264", frames)
}
if validFrames != 263 {
t.Fatalf("validFrames = %d, want 263", validFrames)
}
if len(mel) != 264*128 {
t.Fatalf("len(mel) = %d, want %d", len(mel), 264*128)
}
lastFrame := mel[263*128 : 264*128]
if !slices.Equal(lastFrame, make([]float32, 128)) {
t.Fatal("expected masked final frame to be zero")
}
// Reference values come from the ParakeetExtractor path used by vLLM:
// pre-emphasis, torch.stft(center=True, pad_mode="constant"), Slaney mel
// filters, log guard 2^-24, and per-mel normalization over valid frames.
// Keys are {frame, mel bin}.
checks := map[[2]int]float32{
{0, 0}: -1.0855197,
{0, 50}: -0.93212974,
{1, 10}: -0.9735168,
{2, 100}: -0.6533053,
{50, 0}: 2.2483668,
{50, 127}: -0.3828735,
{100, 50}: 2.9742377,
{262, 0}: -0.9521758,
{262, 127}: -0.4602786,
{263, 50}: 0,
}
for pos, want := range checks {
got := mel[pos[0]*128+pos[1]]
if math.Abs(float64(got-want)) > 1e-4 {
t.Errorf("mel[%d,%d] = %v, want %v", pos[0], pos[1], got, want)
}
}
}
// integrationAudioWAV returns the WAV clip used by the integration tests. The
// clip is stored as a base64 raw-string constant named audioEncodingPrompt in
// integration/audio_test_data_test.go; this helper reads that source file,
// extracts the text between the backticks and decodes it.
func integrationAudioWAV(t *testing.T) []byte {
	t.Helper()

	const marker = "const audioEncodingPrompt = `"

	src, err := os.ReadFile(filepath.Join("..", "..", "..", "integration", "audio_test_data_test.go"))
	if err != nil {
		t.Fatal(err)
	}

	_, rest, found := strings.Cut(string(src), marker)
	if !found {
		t.Fatal("audioEncodingPrompt marker not found")
	}
	encoded, _, found := strings.Cut(rest, "`")
	if !found {
		t.Fatal("audioEncodingPrompt terminator not found")
	}

	wav, err := base64.StdEncoding.DecodeString(strings.TrimSpace(encoded))
	if err != nil {
		t.Fatal(err)
	}
	return wav
}
// TestRelativeShiftParakeetMatchesReference runs relativeShiftParakeet on a
// single-head 3x5 input whose values encode 10*query + position and compares
// the result with hard-coded reference output.
func TestRelativeShiftParakeetMatchesReference(t *testing.T) {
ctx := setupTestContext(t)
seqLen := 3
positionLen := 2*seqLen - 1
values := make([]float32, seqLen*positionLen)
for q := range seqLen {
for p := range positionLen {
values[q*positionLen+p] = float32(q*10 + p)
}
}
x := ctx.FromFloats(values, positionLen, seqLen, 1)
got := relativeShiftParakeet(ctx, x, seqLen, 1)
ctx.Forward(got).Compute(got)
// Each query row keeps three consecutive positions, starting one position
// earlier for every later query.
want := []float32{
2, 3, 4,
11, 12, 13,
20, 21, 22,
}
if !slices.Equal(got.BackendGet(), want) {
t.Fatalf("relative shift mismatch:\n got %v\nwant %v", got.BackendGet(), want)
}
}
// TestAudioDepthwiseConv2DMatchesReference checks audioDepthwiseConv2D plus a
// per-channel bias on a 4x5 input with two channels (3x3 kernel, stride 2,
// padding 1) against hard-coded reference values, with tolerance 1e-5.
func TestAudioDepthwiseConv2DMatchesReference(t *testing.T) {
ctx := setupTestContext(t)
freq, frames, channels := 4, 5, 2
xValues := make([]float32, freq*frames*channels)
for i := range xValues {
xValues[i] = float32(i)/10 - 1
}
kernelValues := make([]float32, 3*3*channels)
for i := range kernelValues {
kernelValues[i] = float32(i)/7 - 1
}
x := ctx.FromFloats(xValues, freq, frames, channels, 1)
kernel := ctx.FromFloats(kernelValues, 3, 3, 1, channels)
bias := ctx.FromFloats([]float32{0.25, -0.5}, channels)
got := audioDepthwiseConv2D(ctx, x, kernel, 2, 2, 1, 1, 1, 1).Add(ctx, bias.Reshape(ctx, 1, 1, -1))
ctx.Forward(got).Compute(got)
want := []float32{
0.86428565, 1.3357141,
1.2785715, 1.3642857,
-0.5928571, -1.7499999,
5.4000001, 8.8142853,
10.514286, 16.042856,
6.6857138, 9.8428574,
}
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
}
// TestFlattenAudioSubsamplingOutputMatchesReference verifies the element order
// produced by flattenAudioSubsamplingOutput: for every frame, the frequency
// bins of channel 0 come first, followed by those of channel 1.
func TestFlattenAudioSubsamplingOutputMatchesReference(t *testing.T) {
	ctx := setupTestContext(t)

	const (
		freq     = 2
		frames   = 3
		channels = 2
	)

	// Every value encodes its own coordinates as 100*channel + 10*frame + bin.
	input := make([]float32, freq*frames*channels)
	for ch := 0; ch < channels; ch++ {
		for frame := 0; frame < frames; frame++ {
			for bin := 0; bin < freq; bin++ {
				input[bin+freq*(frame+frames*ch)] = float32(100*ch + 10*frame + bin)
			}
		}
	}

	x := ctx.FromFloats(input, freq, frames, channels, 1)
	got := flattenAudioSubsamplingOutput(ctx, x)
	ctx.Forward(got).Compute(got)

	want := []float32{
		0, 1, 100, 101,
		10, 11, 110, 111,
		20, 21, 120, 121,
	}
	assertCloseSlice(t, got.BackendGet(), want, 0)
}
// TestAudioDepthwiseConv1DMatchesReference checks audioDepthwiseConv1DSame on
// a two-channel, five-step input with a three-tap kernel and padding 1 against
// hard-coded reference values, with tolerance 1e-5.
func TestAudioDepthwiseConv1DMatchesReference(t *testing.T) {
ctx := setupTestContext(t)
xValues := make([]float32, 2*5)
for i := range xValues {
xValues[i] = float32(i)/5 - 0.7
}
kernelValues := make([]float32, 3*2)
for i := range kernelValues {
kernelValues[i] = float32(i)/3 - 0.5
}
x := ctx.FromFloats(xValues, 2, 5)
kernel := ctx.FromFloats(kernelValues, 3, 2)
got := audioDepthwiseConv1DSame(ctx, x, kernel, 1)
ctx.Forward(got).Compute(got)
want := []float32{
0.066666655, -0.5333333,
0.41666666, 0.016666688,
0.21666668, 1.0166667,
0.01666667, 2.0166664,
-0.40000004, 1.2666667,
}
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
}
// TestAudioSelfAttentionMatchesReference runs AudioSelfAttention.Forward on a
// tiny configuration (hidden size 4, two heads, three positions) in which all
// projections are identity matrices, and compares the output with hard-coded
// reference values, with tolerance 1e-5.
func TestAudioSelfAttentionMatchesReference(t *testing.T) {
ctx := setupTestContext(t)
const (
hiddenSize = 4
numHeads = 2
headDim = 2
seqLen = 3
)
xValues := make([]float32, hiddenSize*seqLen)
for i := range xValues {
xValues[i] = float32(i)/10 - 0.5
}
identity := make([]float32, hiddenSize*hiddenSize)
for i := range hiddenSize {
identity[i*hiddenSize+i] = 1
}
// Each projection gets its own identity-weight tensor.
linear := func() *nn.Linear {
return &nn.Linear{Weight: ctx.FromFloats(identity, hiddenSize, hiddenSize)}
}
attn := &AudioSelfAttention{
Query: linear(),
Key: linear(),
Value: linear(),
Output: linear(),
RelativeKey: linear(),
BiasU: ctx.FromFloats([]float32{0.1, -0.2, 0.3, -0.4}, headDim, numHeads),
BiasV: ctx.FromFloats([]float32{-0.05, 0.07, 0.11, -0.13}, headDim, numHeads),
}
// validLen equals seqLen, so no attention mask is applied.
got := attn.Forward(ctx, ctx.FromFloats(xValues, hiddenSize, seqLen), seqLen, &AudioOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: headDim,
})
ctx.Forward(got).Compute(got)
want := []float32{
-0.08471569, 0.015284289, 0.05532019, 0.1553202,
-0.09135241, 0.008647568, 0.11468154, 0.21468155,
-0.019152153, 0.08084783, 0.1733382, 0.2733382,
}
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
}
// assertCloseSlice fails the test immediately when got and want differ in
// length or when any pair of elements differs by more than tolerance.
func assertCloseSlice(t *testing.T, got, want []float32, tolerance float64) {
	t.Helper()

	if n, m := len(got), len(want); n != m {
		t.Fatalf("len(got) = %d, want %d", n, m)
	}
	for i, expected := range want {
		delta := float64(got[i] - expected)
		if math.Abs(delta) > tolerance {
			t.Fatalf("got[%d] = %v, want %v\nall got: %v", i, got[i], expected, got)
		}
	}
}
// TestPackPatchesCHW packs a 4x4 image with two channel planes into 2x2
// patches and checks that each patch lists channel 0's pixels row by row,
// followed by channel 1's.
func TestPackPatchesCHW(t *testing.T) {
	// Plane 0 holds 0..15 and plane 1 holds 100..115, both row-major.
	pixels := []float32{
		0, 1, 2, 3,
		4, 5, 6, 7,
		8, 9, 10, 11,
		12, 13, 14, 15,
		100, 101, 102, 103,
		104, 105, 106, 107,
		108, 109, 110, 111,
		112, 113, 114, 115,
	}
	want := []float32{
		0, 1, 4, 5, 100, 101, 104, 105,
		2, 3, 6, 7, 102, 103, 106, 107,
		8, 9, 12, 13, 108, 109, 112, 113,
		10, 11, 14, 15, 110, 111, 114, 115,
	}

	got := packVisionPatchesCHW(pixels, 4, 4, 2, 2)
	if len(got) != len(want) {
		t.Fatalf("len(got) = %d, want %d", len(got), len(want))
	}
	for i, expected := range want {
		if got[i] != expected {
			t.Fatalf("got[%d] = %v, want %v", i, got[i], expected)
		}
	}
}
// TestResizePositionEmbeddingMatchesReferenceInterpolation resizes a 2x2
// single-channel position grid to 3x3 and expects the exact interpolated
// values listed in want.
func TestResizePositionEmbeddingMatchesReferenceInterpolation(t *testing.T) {
	grid := []float32{
		0, 10,
		20, 30,
	}
	want := []float32{
		0, 5, 10,
		10, 15, 20,
		20, 25, 30,
	}

	got := resizePositionEmbedding(grid, 1, 2, 2, 3, 3)
	if len(got) != len(want) {
		t.Fatalf("len(got) = %d, want %d", len(got), len(want))
	}
	for i, expected := range want {
		if got[i] != expected {
			t.Fatalf("got[%d] = %v, want %v", i, got[i], expected)
		}
	}
}
// TestDynamicImageProcessorMatchesReferencePatchBudget processes a 400x250
// gradient image with a 1024..13312 patch budget and expects a single
// 672x416 tile holding three channels of data.
func TestDynamicImageProcessorMatchesReferencePatchBudget(t *testing.T) {
	processor := ImageProcessor{
		imageSize:      512,
		patchSize:      16,
		numChannels:    3,
		minNumPatches:  1024,
		maxNumPatches:  13312,
		projectorScale: 2,
		imageMean:      [3]float32{0.48145466, 0.4578275, 0.40821073},
		imageStd:       [3]float32{0.26862954, 0.26130258, 0.27577711},
	}

	const (
		srcWidth  = 400
		srcHeight = 250
	)
	img := image.NewRGBA(image.Rect(0, 0, srcWidth, srcHeight))
	for y := 0; y < srcHeight; y++ {
		for x := 0; x < srcWidth; x++ {
			// Red and green wrap modulo 256 as the coordinates grow.
			img.SetRGBA(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255})
		}
	}

	tiles, err := processor.ProcessImage(img)
	if err != nil {
		t.Fatalf("ProcessImage() error = %v", err)
	}
	if len(tiles) != 1 {
		t.Fatalf("len(tiles) = %d, want %d", len(tiles), 1)
	}

	tile := tiles[0]
	wantSize := image.Point{X: 672, Y: 416}
	if tile.size != wantSize {
		t.Fatalf("tile size = %v, want %v", tile.size, wantSize)
	}
	if wantLen := 3 * 672 * 416; len(tile.data) != wantLen {
		t.Fatalf("tile data len = %d, want %d", len(tile.data), wantLen)
	}
}
// sineWAV synthesizes a mono, 16-bit PCM WAV file containing a full-scale
// sine wave of the given frequency. The file has a 44-byte canonical header
// followed by int(sampleRate*seconds) little-endian samples.
func sineWAV(t *testing.T, sampleRate int, frequency float64, seconds float64) []byte {
	t.Helper()

	// writeLE appends every field to buf in little-endian byte order.
	writeLE := func(buf *bytes.Buffer, fields ...any) {
		for _, field := range fields {
			if err := binary.Write(buf, binary.LittleEndian, field); err != nil {
				t.Fatal(err)
			}
		}
	}

	var pcm bytes.Buffer
	total := int(float64(sampleRate) * seconds)
	for i := 0; i < total; i++ {
		phase := 2 * math.Pi * frequency * float64(i) / float64(sampleRate)
		writeLE(&pcm, int16(math.Sin(phase)*32767))
	}

	var wav bytes.Buffer
	wav.WriteString("RIFF")
	writeLE(&wav, uint32(36+pcm.Len())) // RIFF chunk size
	wav.WriteString("WAVE")
	wav.WriteString("fmt ")
	writeLE(&wav,
		uint32(16),           // fmt chunk size
		uint16(1),            // audio format: PCM
		uint16(1),            // channels
		uint32(sampleRate),   // sample rate
		uint32(sampleRate*2), // byte rate
		uint16(2),            // block align
		uint16(16),           // bits per sample
	)
	wav.WriteString("data")
	writeLE(&wav, uint32(pcm.Len()))
	wav.Write(pcm.Bytes())
	return wav.Bytes()
}

View File

@@ -0,0 +1,348 @@
package nemotronh
import (
"math"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// nemotronVisionBatchSize is the batch dimension used when reshaping vision
// attention tensors.
const nemotronVisionBatchSize = 1
// visionPatchGrid is the size of an image tile measured in patches.
type visionPatchGrid struct {
Width int
Height int
}
// VisionPatchEmbedding projects flattened image patches with its embedded
// linear layer.
type VisionPatchEmbedding struct {
*nn.Linear
}
// packVisionPatchesCHW rearranges a planar image (channel planes stored one
// after another, each row-major with the given width and height) into a
// sequence of square patches. Patches are emitted row by row across the
// image; within a patch the values are ordered by channel, then patch row,
// then patch column. Pixels beyond the last whole patch are dropped.
func packVisionPatchesCHW(values []float32, width, height, channels, patchSize int) []float32 {
	cols, rows := width/patchSize, height/patchSize
	tokenLen := channels * patchSize * patchSize
	planeSize := width * height

	packed := make([]float32, tokenLen*cols*rows)
	dst := packed
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			for ch := 0; ch < channels; ch++ {
				for dy := 0; dy < patchSize; dy++ {
					// One patch row is contiguous in the source plane.
					src := ch*planeSize + (row*patchSize+dy)*width + col*patchSize
					copy(dst[:patchSize], values[src:src+patchSize])
					dst = dst[patchSize:]
				}
			}
		}
	}
	return packed
}
// ForwardPacked uploads already-packed patch values as a
// (patchDim, numPatches) tensor, duplicates it, and applies the linear
// projection.
func (p *VisionPatchEmbedding) ForwardPacked(ctx ml.Context, patches []float32, patchDim, numPatches int) ml.Tensor {
	packed := ctx.Input().FromFloats(patches, patchDim, numPatches).Duplicate(ctx)
	return p.Linear.Forward(ctx, packed)
}
// Forward embeds an image tensor by reading its values back from the backend,
// packing them into patches on the host and projecting the result. Dimensions
// 0, 1 and 2 of pixelValues are read as width, height and channels.
func (p *VisionPatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
	// The packing reproduces the RADIO patch generator's flattening order
	// exactly: token-major, and within a token channel first, then patch row,
	// then patch column. It is more explicit, and probably slower, than the
	// earlier IM2Col path, but it sidesteps backend-specific packing
	// differences that made the converted patch embedder diverge badly from
	// the reference model.
	w := pixelValues.Dim(0)
	h := pixelValues.Dim(1)
	ch := pixelValues.Dim(2)

	tokenLen := ch * patchSize * patchSize
	tokens := (w / patchSize) * (h / patchSize)
	packed := packVisionPatchesCHW(pixelValues.BackendGet(), w, h, ch, patchSize)
	return p.ForwardPacked(ctx, packed, tokenLen, tokens)
}
// VisionSelfAttention holds the query, key, value and output projections of a
// vision encoder attention block. The gguf tags name the tensors they are
// loaded from.
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
// Forward applies unmasked multi-head self-attention to hiddenState. The
// projections are split into opts.numHeads heads of hiddenSize/numHeads
// features each, attended with a 1/sqrt(headDim) scale, merged back to
// hiddenSize and passed through the output projection.
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor {
	heads := opts.numHeads
	perHead := opts.hiddenSize / heads

	// splitHeads views a (hidden, tokens) projection as (perHead, heads, tokens, batch).
	splitHeads := func(t ml.Tensor) ml.Tensor {
		return t.Reshape(ctx, perHead, heads, t.Dim(1), nemotronVisionBatchSize)
	}

	q := sa.Query.Forward(ctx, hiddenState)
	k := sa.Key.Forward(ctx, hiddenState)
	v := sa.Value.Forward(ctx, hiddenState)
	q, k, v = splitHeads(q), splitHeads(k), splitHeads(v)

	scale := 1.0 / math.Sqrt(float64(perHead))
	merged := nn.Attention(ctx, q, k, v, scale, nil)
	merged = merged.Reshape(ctx, opts.hiddenSize, merged.Dim(2), nemotronVisionBatchSize)
	return sa.Output.Forward(ctx, merged)
}
// VisionMLP is the two-layer feed-forward block of a vision encoder layer.
type VisionMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`   // first projection, followed by GELU in Forward
	Down *nn.Linear `gguf:"ffn_down"` // second projection applied to the activated values
}
// Forward computes Down(GELU(Up(hiddenState))).
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
	activated := mlp.Up.Forward(ctx, hiddenState).GELU(ctx)
	return mlp.Down.Forward(ctx, activated)
}
// VisionEncoderLayer is one pre-norm transformer block of the vision encoder:
// a layer norm plus self-attention, then a layer norm plus MLP (see Forward).
type VisionEncoderLayer struct {
	LayerNorm1    *nn.LayerNorm `gguf:"ln1"` // applied before self-attention
	SelfAttention *VisionSelfAttention
	LayerNorm2    *nn.LayerNorm `gguf:"ln2"` // applied before the MLP
	MLP           *VisionMLP
}
// Forward runs one pre-norm encoder block: the input is normalized and
// attended, added back to the input, then normalized and passed through the
// MLP with a second residual connection. Both norms use opts.eps.
func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor {
	// Attention sub-block with residual.
	attended := l.SelfAttention.Forward(ctx, l.LayerNorm1.Forward(ctx, hiddenState, opts.eps), opts)
	hiddenState = attended.Add(ctx, hiddenState)

	// Feed-forward sub-block with residual.
	expanded := l.MLP.Forward(ctx, l.LayerNorm2.Forward(ctx, hiddenState, opts.eps))
	return expanded.Add(ctx, hiddenState)
}
// VisionOptions carries the vision encoder hyperparameters that newVisionModel
// reads from the model configuration.
type VisionOptions struct {
	hiddenSize int     // embedding width ("vision.embedding_length")
	numHeads   int     // attention head count ("vision.attention.head_count")
	imageSize  int     // configured image size ("vision.image_size")
	patchSize  int     // side length of a square patch ("vision.patch_size")
	eps        float32 // layer norm epsilon ("vision.attention.layer_norm_epsilon")
}
// VisionModel is the vision encoder: patch embedding, optional position and
// class (prefix) embeddings, and a stack of encoder layers.
type VisionModel struct {
	PatchEmbedding    *VisionPatchEmbedding `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor             `gguf:"position_embd"` // optional; skipped when nil
	ClassEmbedding    ml.Tensor             `gguf:"cls_embd"`      // optional prefix tokens; skipped when nil
	Layers            []VisionEncoderLayer  `gguf:"blk"`
	*VisionOptions

	// resizedPositionEmbeddings caches CPU-resized position embeddings per
	// patch grid; resizedPositionEmbeddingsMu guards every access to the map.
	resizedPositionEmbeddingsMu sync.Mutex
	resizedPositionEmbeddings   map[visionPatchGrid][]float32
}
// Forward encodes an image tensor: it embeds the pixels into patch tokens and
// runs them through the shared encoder path for the given patch grid.
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
	embedded := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
	return m.forwardPatchEmbeddings(ctx, embedded, patches, patches.Width*patches.Height)
}
// ForwardPacked encodes patches that were already flattened on the host.
// The per-patch width is inferred as len(patchValues)/numPatches and stays 0
// when the grid is empty, which avoids a division by zero.
func (m *VisionModel) ForwardPacked(ctx ml.Context, patchValues []float32, patches visionPatchGrid) ml.Tensor {
	numPatches := patches.Width * patches.Height

	var patchDim int
	if numPatches > 0 {
		patchDim = len(patchValues) / numPatches
	}

	embedded := m.PatchEmbedding.ForwardPacked(ctx, patchValues, patchDim, numPatches)
	return m.forwardPatchEmbeddings(ctx, embedded, patches, numPatches)
}
// forwardPatchEmbeddings is the encoder path shared by Forward and
// ForwardPacked. It adds position embeddings (when present), prepends the
// class/prefix tokens (when present), runs every encoder layer, strips the
// prefix tokens again and returns the result as a 2-D tensor.
func (m *VisionModel) forwardPatchEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor {
	if m.PositionEmbedding != nil {
		hiddenState = hiddenState.Add(ctx, m.positionEmbeddings(ctx, hiddenState, patches, numPatches))
	}

	hasPrefix := m.ClassEmbedding != nil
	if hasPrefix {
		prefixTokens := m.ClassEmbedding.Dim(1)
		prefix := m.ClassEmbedding.Cast(ctx, hiddenState.DType())
		prefix = prefix.Reshape(ctx, prefix.Dim(0), prefixTokens, 1)
		// Prefix tokens go in front of the patch tokens along dim 1.
		hiddenState = prefix.Concat(ctx, hiddenState, 1)
	}

	for i := range m.Layers {
		hiddenState = m.Layers[i].Forward(ctx, hiddenState, m.VisionOptions)
	}

	if hasPrefix {
		// Drop the prefix tokens so only patch tokens are returned.
		hiddenState = hiddenState.Slice(ctx, 1, m.ClassEmbedding.Dim(1), hiddenState.Dim(1), 1)
	}
	return hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1))
}
// positionEmbeddings returns position embeddings matching the patch grid.
//
// When the stored table is not a square grid, or already matches the
// requested grid, it is used as-is (truncated to numPatches tokens if it is
// longer). Otherwise it is resized: preferably from the CPU-side cache, and
// failing that with an in-graph bilinear interpolation.
func (m *VisionModel) positionEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor {
	posTokens := m.PositionEmbedding.Dim(1)
	side := int(math.Sqrt(float64(posTokens)))
	embeddings := m.PositionEmbedding.Cast(ctx, hiddenState.DType())

	isSquareTable := side > 0 && side*side == posTokens
	gridDiffers := side != patches.Width || side != patches.Height
	if !isSquareTable || !gridDiffers {
		if embeddings.Dim(1) > numPatches {
			embeddings = embeddings.Slice(ctx, 1, 0, numPatches, 1)
		}
		return embeddings
	}

	hidden := hiddenState.Dim(0)
	if cached, ok := m.cachePositionEmbeddings(ctx, hidden, patches); ok {
		return ctx.Input().FromFloats(cached, hidden, numPatches)
	}

	// Runner fit/reserve builds worst-case multimodal graphs before weights
	// are loaded, so the CPU cache path cannot materialize source values
	// there. Fall back to a graph-only bilinear resize for reservation; the
	// loaded inference path above still uses the cached CPU-resized data.
	// NOTE(review): this comment used to call the cache the "align-corners"
	// path, but resizePositionEmbedding samples at half-pixel centers
	// (align_corners=false) - confirm which wording is intended.
	embeddings = embeddings.Reshape(ctx, -1, side, side)
	embeddings = embeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
	target := [4]int{patches.Width, patches.Height, hidden, 1}
	embeddings = embeddings.Interpolate(ctx, target, ml.SamplingModeBilinear)
	return embeddings.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx, -1, patches.Width*patches.Height)
}
// cachePositionEmbeddings returns the position embeddings resized on the CPU
// for the given patch grid, computing and caching them on first use.
//
// It reports false when the position embedding tensor has no backing bytes,
// in which case the caller must fall back to a graph-only resize. The lock is
// not held while resizing, so concurrent callers may both compute; the first
// stored result wins and is returned to everyone.
func (m *VisionModel) cachePositionEmbeddings(ctx ml.Context, hidden int, patches visionPatchGrid) ([]float32, bool) {
	m.resizedPositionEmbeddingsMu.Lock()
	hit := m.resizedPositionEmbeddings[patches]
	m.resizedPositionEmbeddingsMu.Unlock()
	if hit != nil {
		return hit, true
	}

	if len(m.PositionEmbedding.Bytes()) == 0 {
		return nil, false
	}

	side := int(math.Sqrt(float64(m.PositionEmbedding.Dim(1))))
	sourceF32 := m.PositionEmbedding.Cast(ctx, ml.DTypeF32)
	ctx.Forward(sourceF32).Compute(sourceF32)

	// RADIO eval-time CPE uses bilinear interpolation with
	// align_corners=false. Resizing on the CPU favours correctness; it is
	// likely slower than a native graph path and should be revisited if the
	// precision/speed tradeoff turns out not to be worthwhile.
	resized := resizePositionEmbedding(sourceF32.Floats(), hidden, side, side, patches.Width, patches.Height)

	m.resizedPositionEmbeddingsMu.Lock()
	defer m.resizedPositionEmbeddingsMu.Unlock()
	if m.resizedPositionEmbeddings == nil {
		m.resizedPositionEmbeddings = make(map[visionPatchGrid][]float32)
	}
	if winner := m.resizedPositionEmbeddings[patches]; winner != nil {
		return winner, true
	}
	m.resizedPositionEmbeddings[patches] = resized
	return resized, true
}
// resizePositionEmbedding bilinearly resamples a token-major embedding grid.
//
// values holds sourceHeight*sourceWidth tokens of hidden floats each, row by
// row. The result has targetHeight*targetWidth tokens in the same layout.
// Sampling uses half-pixel centers (align_corners=false): target index t maps
// to source position scale*(t+0.5)-0.5, with neighbours clamped to the grid.
func resizePositionEmbedding(values []float32, hidden, sourceWidth, sourceHeight, targetWidth, targetHeight int) []float32 {
	// neighbours maps a target index to its two clamped source indices and
	// the blend weight towards the second one.
	neighbours := func(target int, scale float64, size int) (lo, hi int, weight float32) {
		pos := scale*(float64(target)+0.5) - 0.5
		base := int(math.Floor(pos))
		return max(base, 0), min(base+1, size-1), float32(pos - float64(base))
	}

	resized := make([]float32, hidden*targetWidth*targetHeight)
	scaleX := float64(sourceWidth) / float64(targetWidth)
	scaleY := float64(sourceHeight) / float64(targetHeight)

	for row := 0; row < targetHeight; row++ {
		y0, y1, wy := neighbours(row, scaleY, sourceHeight)
		for col := 0; col < targetWidth; col++ {
			x0, x1, wx := neighbours(col, scaleX, sourceWidth)

			// Offsets of the four surrounding source tokens.
			topLeft := (y0*sourceWidth + x0) * hidden
			topRight := (y0*sourceWidth + x1) * hidden
			botLeft := (y1*sourceWidth + x0) * hidden
			botRight := (y1*sourceWidth + x1) * hidden
			dst := (row*targetWidth + col) * hidden

			for h := 0; h < hidden; h++ {
				v00 := values[topLeft+h]
				v01 := values[topRight+h]
				v10 := values[botLeft+h]
				v11 := values[botRight+h]
				top := v00 + (v01-v00)*wx
				bot := v10 + (v11-v10)*wx
				resized[dst+h] = top + (bot-top)*wy
			}
		}
	}
	return resized
}
// newVisionModel builds an empty vision encoder sized from the model
// configuration. Each key falls back to the default given here when absent.
func newVisionModel(c fs.Config) *VisionModel {
	layers := make([]VisionEncoderLayer, c.Uint("vision.block_count", 32))

	opts := &VisionOptions{
		hiddenSize: int(c.Uint("vision.embedding_length", 1280)),
		numHeads:   int(c.Uint("vision.attention.head_count", 16)),
		imageSize:  int(c.Uint("vision.image_size", 512)),
		patchSize:  int(c.Uint("vision.patch_size", 16)),
		eps:        c.Float("vision.attention.layer_norm_epsilon", 1e-6),
	}

	return &VisionModel{Layers: layers, VisionOptions: opts}
}
// MultiModalProjector turns vision encoder outputs into the embeddings
// returned by its Forward method: pixel shuffle, RMS norm, then a two-layer MLP.
type MultiModalProjector struct {
	Norm    *nn.RMSNorm `gguf:"norm"`
	Linear1 *nn.Linear  `gguf:"1"`
	Linear2 *nn.Linear  `gguf:"2"`

	// scaleFactor is the pixel-shuffle factor per grid axis; values below 1
	// are treated as 1 by Forward.
	scaleFactor int
}
// Forward projects vision encoder outputs.
//
// The reference projector first pixel-shuffles the vision grid with
// downsample_ratio=0.5 before the RMSNorm/MLP; that exact v2 packing order is
// preserved here instead of flattening 2x2 neighbourhoods via IM2Col. The MLP
// activation is a squared ReLU (ReLU followed by an element-wise self-multiply).
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid) ml.Tensor {
	factor := max(p.scaleFactor, 1)

	shuffled := pixelShuffleVisionOutputs(ctx, visionOutputs, patches, factor)
	normed := p.Norm.Forward(ctx, shuffled, 1e-5)

	activated := p.Linear1.Forward(ctx, normed).RELU(ctx)
	squared := activated.Mul(ctx, activated)
	return p.Linear2.Forward(ctx, squared)
}
// pixelShuffleVisionOutputs folds scaleFactor x scaleFactor neighbourhoods of
// the patch grid into the channel dimension. The result has
// hiddenSize*scaleFactor*scaleFactor channels and
// (patches.Width/scaleFactor)*(patches.Height/scaleFactor) tokens.
// A scaleFactor below 1 is treated as 1.
func pixelShuffleVisionOutputs(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, scaleFactor int) ml.Tensor {
	scaleFactor = max(scaleFactor, 1)
	hiddenSize := visionOutputs.Dim(0)

	outWidth := patches.Width / scaleFactor
	outHeight := patches.Height / scaleFactor
	foldedChannels := hiddenSize * scaleFactor

	grid := visionOutputs.Reshape(ctx, hiddenSize, patches.Width, patches.Height, 1)
	// Fold horizontally adjacent patches into the channel dimension.
	grid = grid.Reshape(ctx, foldedChannels, outWidth, patches.Height, 1)
	// Split the rows into (scaleFactor, outHeight), then move the in-group
	// row index next to the channels so it can be folded as well.
	grid = grid.Reshape(ctx, foldedChannels, outWidth, scaleFactor, outHeight)
	grid = grid.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	return grid.Reshape(ctx, foldedChannels*scaleFactor, outWidth*outHeight, 1)
}
// newMultiModalProjector builds a projector whose pixel-shuffle factor comes
// from "vision.projector.scale_factor" (default 2).
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
	projector := &MultiModalProjector{}
	projector.scaleFactor = int(c.Uint("vision.projector.scale_factor", 2))
	return projector
}

View File

@@ -0,0 +1,328 @@
package nemotronh
import (
"encoding/binary"
"fmt"
"math"
"math/cmplx"
)
// Parameters of the Parakeet log-mel front end used by
// computeParakeetMelSpectrogram.
const (
	parakeetHopLength         = 160              // samples between consecutive frames
	parakeetNFFT              = 512              // FFT size; yields NFFT/2+1 frequency bins
	parakeetWinLength         = 400              // analysis window length in samples
	parakeetPreemphasis       = 0.97             // coefficient of the y[i] = x[i] - c*x[i-1] filter
	parakeetLogZeroGuardValue = 1.0 / (1 << 24) // added to mel energies before the log (2^-24)
	parakeetNormalizeEps      = 1e-5             // added to the std-dev when normalizing
)
// isAudioData reports whether data starts with a RIFF/WAVE container header:
// "RIFF" at bytes 0-3 and "WAVE" at bytes 8-11.
func isAudioData(data []byte) bool {
	if len(data) < 12 {
		return false
	}
	return string(data[0:4]) == "RIFF" && string(data[8:12]) == "WAVE"
}
// decodeWAV parses a RIFF/WAVE file and returns its audio as mono float32
// samples at targetSampleRate.
//
// Supported encodings are integer PCM (format 1) at 8, 16, 24 or 32 bits and
// IEEE float (format 3) at 32 bits; WAVE_FORMAT_EXTENSIBLE (0xfffe) is
// resolved through its sub-format field. Channels are averaged into mono and
// the result is linearly resampled when the file's rate differs from
// targetSampleRate. A chunk that claims more bytes than the file holds is
// truncated to what is available; if several data chunks exist the last wins.
//
// An error is returned for a missing or malformed header, an unsupported
// format or bit depth, or a non-positive channel count or sample rate.
func decodeWAV(data []byte, targetSampleRate int) ([]float32, error) {
	if len(data) < 12 {
		return nil, fmt.Errorf("WAV file too short")
	}
	if !isAudioData(data) {
		return nil, fmt.Errorf("not a WAV file")
	}

	var audioFormat uint16
	var numChannels, sampleRate, bitsPerSample int
	var audioData []byte
	foundFmt := false

	offset := 12
	for offset+8 <= len(data) {
		chunkID := string(data[offset : offset+4])
		chunkData := data[offset+8:]

		// Compare the declared size as an unsigned value so that a huge size
		// cannot turn negative when int is 32 bits wide.
		declared := uint64(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
		truncated := declared > uint64(len(chunkData))
		if !truncated {
			chunkData = chunkData[:declared]
		}

		switch chunkID {
		case "fmt ":
			if len(chunkData) < 16 {
				return nil, fmt.Errorf("fmt chunk too short")
			}
			audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
			numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
			sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
			bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
			if audioFormat == 0xfffe && len(chunkData) >= 26 {
				// WAVE_FORMAT_EXTENSIBLE: the real format code is the first
				// two bytes of the sub-format GUID.
				audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
			}
			foundFmt = true
		case "data":
			audioData = chunkData
		}

		if truncated {
			// The chunk ran to the end of the file; nothing can follow it.
			break
		}
		offset += 8 + len(chunkData)
		if declared%2 != 0 {
			// Chunks are padded to an even length.
			offset++
		}
	}

	if !foundFmt {
		return nil, fmt.Errorf("no fmt chunk found in WAV file")
	}
	if audioFormat != 1 && audioFormat != 3 {
		return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
	}
	if audioData == nil {
		return nil, fmt.Errorf("no data chunk found in WAV file")
	}
	if numChannels <= 0 {
		return nil, fmt.Errorf("invalid WAV channel count: %d", numChannels)
	}
	if sampleRate <= 0 {
		// Without this check the resampler would silently pass the samples
		// through as if they were already at the target rate.
		return nil, fmt.Errorf("invalid WAV sample rate: %d", sampleRate)
	}

	// Reject encodings the sample decoder has no case for; it would otherwise
	// produce silence (or no samples at all) without reporting a problem.
	switch {
	case audioFormat == 1 && (bitsPerSample == 8 || bitsPerSample == 16 || bitsPerSample == 24 || bitsPerSample == 32):
	case audioFormat == 3 && bitsPerSample == 32:
	default:
		return nil, fmt.Errorf("unsupported WAV bit depth: %d bits for format %d", bitsPerSample, audioFormat)
	}

	samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
	if sampleRate != targetSampleRate {
		samples = resampleLinear(samples, sampleRate, targetSampleRate)
	}
	return samples, nil
}
// decodeWAVSamples converts interleaved WAV sample data into mono float32
// samples by averaging the channels of every frame.
//
// format is the WAV format code (1 = integer PCM, 3 = IEEE float) and bits the
// bit depth. Integer samples are scaled to [-1, 1); 8-bit samples are unsigned
// with a 128 offset. Combinations without a decoder contribute 0 for every
// sample. It returns nil when bits/8 or channels is not positive; trailing
// bytes that do not fill a whole frame are ignored.
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
	sampleBytes := bits / 8
	if sampleBytes <= 0 || channels <= 0 {
		return nil
	}

	// Choose the per-sample decoder once; b is exactly sampleBytes long.
	decode := func(b []byte) float64 { return 0 }
	switch {
	case format == 1 && bits == 8:
		decode = func(b []byte) float64 {
			return (float64(b[0]) - 128.0) / 128.0
		}
	case format == 1 && bits == 16:
		decode = func(b []byte) float64 {
			return float64(int16(binary.LittleEndian.Uint16(b))) / 32768.0
		}
	case format == 1 && bits == 24:
		decode = func(b []byte) float64 {
			// Place the 24 bits in the top of an int32, then shift back down
			// arithmetically to sign-extend.
			v := (int32(b[0])<<8 | int32(b[1])<<16 | int32(b[2])<<24) >> 8
			return float64(v) / 8388608.0
		}
	case format == 1 && bits == 32:
		decode = func(b []byte) float64 {
			return float64(int32(binary.LittleEndian.Uint32(b))) / 2147483648.0
		}
	case format == 3 && bits == 32:
		decode = func(b []byte) float64 {
			return float64(math.Float32frombits(binary.LittleEndian.Uint32(b)))
		}
	}

	frameBytes := sampleBytes * channels
	frameCount := len(data) / frameBytes
	mono := make([]float32, frameCount)
	for frame := 0; frame < frameCount; frame++ {
		var sum float64
		for ch := 0; ch < channels; ch++ {
			start := frame*frameBytes + ch*sampleBytes
			if start+sampleBytes > len(data) {
				break
			}
			sum += decode(data[start : start+sampleBytes])
		}
		mono[frame] = float32(sum / float64(channels))
	}
	return mono
}
// resampleLinear converts samples from fromRate to toRate by linear
// interpolation. The first and last input samples map to the first and last
// output samples. The input is returned unchanged when it is empty or either
// rate is not positive; when the output would hold at most one sample, only
// the first input sample is returned.
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
	if len(samples) == 0 || fromRate <= 0 || toRate <= 0 {
		return samples
	}

	n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
	if n <= 1 {
		return slicesCloneOne(samples)
	}

	last := len(samples) - 1
	out := make([]float32, n)
	for i := range out {
		pos := float64(i) * float64(len(samples)-1) / float64(n-1)
		idx := int(pos)
		if idx >= last {
			// No right-hand neighbour to blend with.
			out[i] = samples[idx]
			continue
		}
		frac := float32(pos - float64(idx))
		out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
	}
	return out
}
// slicesCloneOne returns a new one-element slice holding the first sample,
// or nil when samples is empty.
func slicesCloneOne(samples []float32) []float32 {
	if len(samples) > 0 {
		return []float32{samples[0]}
	}
	return nil
}
// computeParakeetMelSpectrogram turns mono samples into normalized log-mel
// features.
//
// The signal is pre-emphasized, cut into centered frames every
// parakeetHopLength samples, windowed, transformed with a parakeetNFFT-point
// FFT and projected through the mel filter bank supplied by extractor. Each
// mel bin is then normalized to zero mean and unit (sample) standard
// deviation over the valid frames; frames past validFrames are zeroed.
//
// It returns the features in frame-major order (index frame*melBins+mel), the
// total frame count, the number of valid frames, and an error when samples is
// empty or the extractor's window/filter shapes are unexpected. A nil opts
// selects defaultAudioOptions.
func computeParakeetMelSpectrogram(samples []float32, extractor *AudioFeatureExtractor, opts *AudioOptions) ([]float32, int, int, error) {
	if len(samples) == 0 {
		return nil, 0, 0, fmt.Errorf("audio too short to encode")
	}
	if opts == nil {
		opts = defaultAudioOptions()
	}

	melBins := opts.melBins
	freqBins := parakeetNFFT/2 + 1
	window, melFilters := extractor.windowAndFilters(melBins, freqBins, opts.sampleRate)
	if len(window) != parakeetWinLength {
		return nil, 0, 0, fmt.Errorf("invalid Parakeet window length: %d", len(window))
	}
	if len(melFilters) != melBins*freqBins {
		return nil, 0, 0, fmt.Errorf("invalid Parakeet mel filter shape: %d", len(melFilters))
	}

	emphasized := make([]float32, len(samples))
	emphasized[0] = samples[0]
	for i := 1; i < len(samples); i++ {
		emphasized[i] = samples[i] - parakeetPreemphasis*samples[i-1]
	}

	// validFrames never exceeds frames: it is either 1 or frames-1.
	frames := len(samples)/parakeetHopLength + 1
	validFrames := max(1, len(samples)/parakeetHopLength)

	result := make([]float32, frames*melBins)
	fftInput := make([]complex128, parakeetNFFT)
	magnitudes := make([]float64, freqBins)
	winOffset := (parakeetNFFT - parakeetWinLength) / 2
	centerPad := parakeetNFFT / 2

	for frame := range frames {
		clear(fftInput)
		// The window sits in the middle of the FFT buffer and the frame is
		// centered on frame*hop; samples outside the signal stay zero.
		for i := range parakeetWinLength {
			src := frame*parakeetHopLength + i + winOffset - centerPad
			if src >= 0 && src < len(emphasized) {
				fftInput[i+winOffset] = complex(float64(emphasized[src])*float64(window[i]), 0)
			}
		}
		fft(fftInput)

		// The spectrum magnitudes do not depend on the mel bin, so compute
		// them once per frame rather than once per (mel, frequency) pair.
		for freq := range magnitudes {
			magnitudes[freq] = cmplx.Abs(fftInput[freq])
		}

		for mel := range melBins {
			var v float64
			filter := melFilters[mel*freqBins : (mel+1)*freqBins]
			for freq, mag := range magnitudes {
				v += float64(filter[freq]) * mag * mag
			}
			result[frame*melBins+mel] = float32(math.Log(v + parakeetLogZeroGuardValue))
		}
	}

	for mel := range melBins {
		var sum float64
		for frame := range validFrames {
			sum += float64(result[frame*melBins+mel])
		}
		mean := sum / float64(validFrames)

		var variance float64
		for frame := range validFrames {
			d := float64(result[frame*melBins+mel]) - mean
			variance += d * d
		}
		// Sample standard deviation (n-1), guarded for a single valid frame.
		denom := max(1, validFrames-1)
		std := math.Sqrt(variance / float64(denom))

		for frame := range frames {
			idx := frame*melBins + mel
			if frame >= validFrames {
				result[idx] = 0
				continue
			}
			result[idx] = float32((float64(result[idx]) - mean) / (std + parakeetNormalizeEps))
		}
	}
	return result, frames, validFrames, nil
}
// defaultParakeetWindow returns a symmetric Hann window of parakeetWinLength
// samples: 0.5 - 0.5*cos(2*pi*i/(N-1)), which is zero at both ends.
func defaultParakeetWindow() []float32 {
	hann := make([]float32, parakeetWinLength)
	for i := 0; i < len(hann); i++ {
		hann[i] = float32(0.5 - 0.5*math.Cos(2*math.Pi*float64(i)/float64(parakeetWinLength-1)))
	}
	return hann
}
// buildSlaneyMelFilterBank builds numMels triangular mel filters over
// numFreqBins FFT bins, stored row-major (index mel*numFreqBins+freq).
//
// The mel scale is linear below 1 kHz (200/3 Hz per mel) and logarithmic
// above it. Filter edges are spaced evenly in mel between 0 Hz and
// sampleRate/2, and each triangle is scaled by 2/(right-left). Bin
// frequencies assume a parakeetNFFT-point FFT.
func buildSlaneyMelFilterBank(numFreqBins, numMels int, sampleRate int) []float32 {
	toMel := func(hz float64) float64 {
		if hz < 1000 {
			return 3 * hz / 200
		}
		return 15 + math.Log(hz/1000)*27/math.Log(6.4)
	}
	toHz := func(mel float64) float64 {
		if mel < 15 {
			return 200 * mel / 3
		}
		return 1000 * math.Exp(math.Log(6.4)*(mel-15)/27)
	}

	// numMels filters need numMels+2 edge frequencies.
	lowMel, highMel := toMel(0), toMel(float64(sampleRate)/2)
	edges := make([]float64, numMels+2)
	for i := range edges {
		mel := lowMel + (highMel-lowMel)*float64(i)/float64(numMels+1)
		edges[i] = toHz(mel)
	}

	binHz := make([]float64, numFreqBins)
	for bin := range binHz {
		binHz[bin] = float64(bin) * float64(sampleRate) / float64(parakeetNFFT)
	}

	bank := make([]float32, numMels*numFreqBins)
	for mel := 0; mel < numMels; mel++ {
		left, center, right := edges[mel], edges[mel+1], edges[mel+2]
		enorm := 2.0 / (right - left)
		for bin, hz := range binHz {
			// A degenerate slope (coincident edges) stays at 0.
			rising, falling := 0.0, 0.0
			if center > left {
				rising = (hz - left) / (center - left)
			}
			if right > center {
				falling = (right - hz) / (right - center)
			}
			weight := math.Max(0, math.Min(rising, falling))
			bank[mel*numFreqBins+bin] = float32(weight * enorm)
		}
	}
	return bank
}
// fft replaces x with its forward discrete Fourier transform using an
// iterative radix-2 decimation-in-time algorithm (twiddles e^(-2*pi*i/size)).
// Inputs of length 0 or 1 are left untouched.
// NOTE(review): the bit-reversal and butterfly loops assume len(x) is a power
// of two; the only size visible in this file is parakeetNFFT (512).
func fft(x []complex128) {
	n := len(x)
	if n <= 1 {
		return
	}

	// Reorder into bit-reversed index order; rev tracks the reversal of i by
	// incrementing from the most significant bit downwards.
	for i, rev := 1, 0; i < n; i++ {
		mask := n >> 1
		for ; rev&mask != 0; mask >>= 1 {
			rev ^= mask
		}
		rev ^= mask
		if i < rev {
			x[i], x[rev] = x[rev], x[i]
		}
	}

	// Combine sub-transforms of doubling size.
	for size := 2; size <= n; size *= 2 {
		half := size / 2
		step := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
		for base := 0; base < n; base += size {
			twiddle := complex(1, 0)
			for lo, hi := base, base+half; lo < base+half; lo, hi = lo+1, hi+1 {
				odd := twiddle * x[hi]
				even := x[lo]
				x[hi] = even - odd
				x[lo] = even + odd
				twiddle *= step
			}
		}
	}
}

498
model/parsers/laguna.go Normal file
View File

@@ -0,0 +1,498 @@
package parsers
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
// Literal tags the Laguna parser looks for in model output.
const (
	lagunaThinkingOpenTag  = "<think>"
	lagunaThinkingCloseTag = "</think>"
	lagunaToolCallOpenTag  = "<tool_call>"
	lagunaToolCallCloseTag = "</tool_call>"
	// A user block is only inspected when tools are configured; its body is
	// then tried as a tool call (see consumeContent).
	lagunaUserOpenTag  = "<user>"
	lagunaUserCloseTag = "</user>"
)
// lagunaParserState identifies which kind of output the parser is currently
// consuming.
type lagunaParserState int

const (
	lagunaParserStateThinking lagunaParserState = iota // inside a thinking block
	lagunaParserStateContent                           // regular visible content
	lagunaParserStateTool                              // inside a tool call body
)
// LagunaParser incrementally splits Laguna model output into visible content,
// thinking text and tool calls. Call Init before the first Add.
type LagunaParser struct {
	state     lagunaParserState // what the buffered text is currently parsed as
	buffer    strings.Builder   // text received but not yet emitted
	tools     []api.Tool        // tools supplied to Init; used to resolve names and argument types
	callIndex int               // index assigned to the next parsed tool call

	thinkingEnabled    bool // thinking text is returned to the caller
	thinkingSuppressed bool // thinking was explicitly disabled; thinking text is parsed but dropped

	// allowLeadingThinkOpen lets consumeThinking strip one leading thinking
	// open tag. NOTE(review): nothing visible in this file sets it to true.
	allowLeadingThinkOpen bool
}
// HasToolSupport reports that the Laguna parser extracts tool calls.
func (p *LagunaParser) HasToolSupport() bool {
	return true
}
// HasThinkingSupport reports that the Laguna parser separates thinking text
// from visible content.
func (p *LagunaParser) HasThinkingSupport() bool {
	return true
}
// Init resets the parser for a new response and returns tools unchanged.
//
// Thinking is enabled when thinkValue is nil or true, and suppressed only
// when it is explicitly false. Parsing always starts in the content state.
// lastMessage is not used.
func (p *LagunaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	suppressed := thinkValue != nil && !thinkValue.Bool()

	p.buffer.Reset()
	p.tools = tools
	p.callIndex = 0
	p.state = lagunaParserStateContent
	p.allowLeadingThinkOpen = false
	p.thinkingSuppressed = suppressed
	p.thinkingEnabled = !suppressed
	return tools
}
// Add feeds the next chunk of model output to the parser and returns whatever
// can be emitted unambiguously so far: visible content, thinking text and
// completed tool calls. done marks the final chunk and forces buffered text
// to be flushed. Thinking text is discarded unless thinking is enabled. On a
// tool-call parse error everything collected in this call is dropped and the
// error is returned.
func (p *LagunaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)

	var visible, reasoning strings.Builder
	// Keep stepping the state machine until a step makes no progress.
	for advanced := true; advanced; {
		advanced = false
		switch p.state {
		case lagunaParserStateThinking:
			var chunk string
			advanced, chunk = p.consumeThinking(done)
			if p.thinkingEnabled {
				reasoning.WriteString(chunk)
			}
		case lagunaParserStateContent:
			var (
				chunk string
				found []api.ToolCall
			)
			advanced, chunk, found, err = p.consumeContent(done)
			if err != nil {
				return "", "", nil, err
			}
			visible.WriteString(chunk)
			calls = append(calls, found...)
		case lagunaParserStateTool:
			var call api.ToolCall
			advanced, call, err = p.consumeTool(done)
			if err != nil {
				return "", "", nil, err
			}
			if advanced {
				calls = append(calls, call)
			}
		}
	}
	return visible.String(), reasoning.String(), calls, nil
}
// consumeThinking performs one parsing step while in the thinking state.
//
// It returns whether progress was made and the thinking text that can be
// emitted. The state changes to content when the thinking close tag is found
// (or on done), and to tool when a tool-call open tag is found. Otherwise
// text is emitted up to any suffix that might still become a tag or is
// trailing whitespace (which may yet be trimmed before a tool call).
func (p *LagunaParser) consumeThinking(done bool) (bool, string) {
	acc := p.buffer.String()

	if p.allowLeadingThinkOpen {
		trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
		if strings.HasPrefix(trimmed, lagunaThinkingOpenTag) {
			// Swallow the redundant open tag and the whitespace around it.
			p.buffer.Reset()
			p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingOpenTag), unicode.IsSpace))
			p.allowLeadingThinkOpen = false
			return true, ""
		}
		if strings.HasPrefix(lagunaThinkingOpenTag, trimmed) && !done {
			// Still a possible prefix of the open tag: wait for more input.
			return false, ""
		}
		p.allowLeadingThinkOpen = false
	}

	if idx := strings.Index(acc, lagunaThinkingCloseTag); idx != -1 {
		thinking := acc[:idx]
		after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingCloseTag):], unicode.IsSpace)
		p.buffer.Reset()
		p.buffer.WriteString(after)
		p.state = lagunaParserStateContent
		return true, thinking
	}

	if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 {
		// A tool call started without the thinking block being closed.
		thinking := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
		after := acc[idx+len(lagunaToolCallOpenTag):]
		p.buffer.Reset()
		p.buffer.WriteString(after)
		p.state = lagunaParserStateTool
		return true, thinking
	}

	if done {
		// Unterminated thinking: flush everything as thinking text.
		p.buffer.Reset()
		p.state = lagunaParserStateContent
		return acc != "", acc
	}

	overlapLen := max(overlap(acc, lagunaThinkingCloseTag), overlap(acc, lagunaToolCallOpenTag))
	trailingLen := trailingWhitespaceLen(acc)
	keep := max(overlapLen, trailingLen)
	if keep > 0 && keep < len(acc) {
		emit := acc[:len(acc)-keep]
		p.buffer.Reset()
		p.buffer.WriteString(acc[len(acc)-keep:])
		return emit != "", emit
	}
	if keep == 0 && acc != "" {
		// Nothing needs to be held back, so stream the text right away. This
		// mirrors consumeContent; without it thinking text would sit in the
		// buffer until a later chunk ended in whitespace or a tag prefix.
		p.buffer.Reset()
		return true, acc
	}
	return false, ""
}
// consumeContent performs one parsing step while in the content state.
//
// It returns whether progress was made, the visible content that can be
// emitted, any tool calls completed in this step, and a tool-call parse
// error. In order, it handles: a thinking open tag (switches to the thinking
// state), a stray leading thinking close tag (dropped), a tool-call open tag
// (switches to the tool state), a user block whose body is a known tool call,
// and a standalone JSON tool call. The last two are only tried when tools are
// configured. Remaining text is emitted up to any suffix that could still
// turn into a tag or is trailing whitespace.
func (p *LagunaParser) consumeContent(done bool) (bool, string, []api.ToolCall, error) {
	acc := p.buffer.String()

	// Init always sets exactly one of these flags, so this is true for every
	// initialized parser; both flags get the same tag handling below.
	thinkingAware := p.thinkingEnabled || p.thinkingSuppressed

	if thinkingAware {
		if idx := strings.Index(acc, lagunaThinkingOpenTag); idx != -1 {
			content := acc[:idx]
			after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingOpenTag):], unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			p.state = lagunaParserStateThinking
			p.allowLeadingThinkOpen = false
			return true, content, nil, nil
		}
		if !done {
			// Emit what precedes a possible partial thinking open tag.
			overlapLen := overlap(acc, lagunaThinkingOpenTag)
			if overlapLen > 0 && overlapLen < len(acc) {
				content := acc[:len(acc)-overlapLen]
				p.buffer.Reset()
				p.buffer.WriteString(acc[len(acc)-overlapLen:])
				return content != "", content, nil, nil
			}
		}

		// Drop a thinking close tag that leads the content without a
		// matching open tag.
		trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
		if strings.HasPrefix(trimmed, lagunaThinkingCloseTag) {
			p.buffer.Reset()
			p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingCloseTag), unicode.IsSpace))
			return true, "", nil, nil
		}
		if strings.HasPrefix(lagunaThinkingCloseTag, trimmed) && !done {
			// Could still become a close tag: wait for more input.
			return false, "", nil, nil
		}
	}

	if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 {
		content := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
		after := acc[idx+len(lagunaToolCallOpenTag):]
		p.buffer.Reset()
		p.buffer.WriteString(after)
		p.state = lagunaParserStateTool
		return true, content, nil, nil
	}

	if idx := strings.Index(acc, lagunaUserOpenTag); idx != -1 && len(p.tools) > 0 {
		before := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
		afterOpen := acc[idx+len(lagunaUserOpenTag):]
		if closeIdx := strings.Index(afterOpen, lagunaUserCloseTag); closeIdx != -1 {
			raw := afterOpen[:closeIdx]
			// Only a body naming a known tool is taken as a call; anything
			// else falls through and is treated as ordinary content.
			if call, ok := p.parseToolAlias(raw); ok {
				after := strings.TrimLeftFunc(afterOpen[closeIdx+len(lagunaUserCloseTag):], unicode.IsSpace)
				p.buffer.Reset()
				p.buffer.WriteString(after)
				return true, before, []api.ToolCall{call}, nil
			}
		} else if !done {
			// Unclosed user block: emit what precedes it and keep buffering.
			if idx > 0 {
				p.buffer.Reset()
				p.buffer.WriteString(acc[idx:])
				return true, before, nil, nil
			}
			return false, "", nil, nil
		}
	}

	if len(p.tools) > 0 {
		if progress, content, call, ok, err := p.consumeStandaloneJSONTool(done); ok || err != nil {
			if err != nil {
				return false, "", nil, err
			}
			if progress {
				return true, content, []api.ToolCall{call}, nil
			}
			return false, "", nil, nil
		}
	}

	if done {
		p.buffer.Reset()
		return acc != "", acc, nil, nil
	}

	overlapLen := max(overlap(acc, lagunaToolCallOpenTag), overlap(acc, lagunaUserOpenTag))
	if thinkingAware {
		overlapLen = max(overlapLen, overlap(acc, lagunaThinkingOpenTag))
	}
	if p.thinkingSuppressed {
		overlapLen = max(overlapLen, overlap(acc, lagunaThinkingCloseTag))
	}
	trailingLen := trailingWhitespaceLen(acc)
	keep := max(overlapLen, trailingLen)
	if keep > 0 && keep < len(acc) {
		emit := acc[:len(acc)-keep]
		p.buffer.Reset()
		p.buffer.WriteString(acc[len(acc)-keep:])
		return emit != "", emit, nil, nil
	}
	if keep == 0 && acc != "" {
		p.buffer.Reset()
		return true, acc, nil, nil
	}
	return false, "", nil, nil
}
// consumeStandaloneJSONTool tries to read a tool call written as a bare JSON
// object (no surrounding tool-call tags) starting at the first "{" in the
// buffer.
//
// ok reports whether the buffer was recognized as a (possibly still
// incomplete) JSON tool call; when it is false the caller should treat the
// text as ordinary content. progress reports whether anything was consumed:
// content is the text preceding the object, and call is set once the object
// has been parsed. A parse failure is returned as err with ok set.
func (p *LagunaParser) consumeStandaloneJSONTool(done bool) (progress bool, content string, call api.ToolCall, ok bool, err error) {
	acc := p.buffer.String()
	jsonIdx := strings.Index(acc, "{")
	if jsonIdx == -1 {
		return false, "", api.ToolCall{}, false, nil
	}

	before := strings.TrimRightFunc(acc[:jsonIdx], unicode.IsSpace)
	raw := strings.TrimLeftFunc(acc[jsonIdx:], unicode.IsSpace)
	if !lagunaLooksLikeJSONToolCall(raw, done) {
		return false, "", api.ToolCall{}, false, nil
	}

	// While streaming, wait until the object is complete, valid JSON. Any
	// text in front of it can be emitted right away.
	if !done && !json.Valid([]byte(strings.TrimSpace(raw))) {
		if before != "" {
			p.buffer.Reset()
			p.buffer.WriteString(acc[jsonIdx:])
			return true, before, api.ToolCall{}, true, nil
		}
		return false, "", api.ToolCall{}, true, nil
	}

	call, err = parseLagunaToolCall(raw, p.tools)
	if err != nil {
		return false, "", api.ToolCall{}, true, err
	}
	call.Function.Index = p.callIndex
	p.callIndex++
	// The whole buffer has been consumed by the call.
	p.buffer.Reset()
	p.state = lagunaParserStateContent
	return true, before, call, true, nil
}
// lagunaLooksLikeJSONToolCall is the heuristic that decides whether raw may
// be a bare JSON tool call. After optional leading whitespace the text must
// start with "{". It is accepted if it mentions a "name" or "arguments" key.
// While output is still streaming (done is false) an object that has only
// just opened with a quote or a newline is accepted too, so the parser keeps
// buffering it.
func lagunaLooksLikeJSONToolCall(raw string, done bool) bool {
	candidate := strings.TrimLeftFunc(raw, unicode.IsSpace)
	if !strings.HasPrefix(candidate, "{") {
		return false
	}

	for _, key := range []string{`"name"`, `"arguments"`} {
		if strings.Contains(candidate, key) {
			return true
		}
	}
	if done {
		return false
	}

	for _, opening := range []string{`{"`, "{\n", "{\r\n"} {
		if strings.HasPrefix(candidate, opening) {
			return true
		}
	}
	return false
}
// parseToolAlias parses raw as a tool call, but only if it names a tool that
// resolves (directly or via an alias) to one of the configured tools. It
// reports false, without consuming a call index, when there is no name, the
// tool is unknown, or parsing fails.
func (p *LagunaParser) parseToolAlias(raw string) (api.ToolCall, bool) {
	cleaned := cleanLagunaToolCallRaw(raw)

	name, named := lagunaToolCallName(cleaned)
	if !named {
		return api.ToolCall{}, false
	}
	if _, known := lagunaResolveToolName(name, p.tools); !known {
		return api.ToolCall{}, false
	}

	call, err := parseLagunaToolCall(cleaned, p.tools)
	if err != nil {
		return api.ToolCall{}, false
	}

	call.Function.Index = p.callIndex
	p.callIndex++
	return call, true
}
// lagunaResolveToolName maps a tool name emitted by the model onto a
// configured tool. An exact match wins; otherwise a fixed set of aliases
// (read_file, write_file, edit_file, web_fetch) is tried. It returns the
// resolved name and true, or the original name and false when nothing matches.
func lagunaResolveToolName(name string, tools []api.Tool) (string, bool) {
	configured := func(candidate string) bool {
		for i := range tools {
			if tools[i].Function.Name == candidate {
				return true
			}
		}
		return false
	}

	if configured(name) {
		return name, true
	}

	var alias string
	switch name {
	case "read_file":
		alias = "read"
	case "write_file":
		alias = "write"
	case "edit_file":
		alias = "edit"
	case "web_fetch":
		alias = "webfetch"
	default:
		return name, false
	}

	if configured(alias) {
		return alias, true
	}
	return name, false
}
// cleanLagunaToolCallRaw reduces raw to the body of a single tool call:
// surrounding whitespace and leading tool-call open tags are removed,
// everything from the first close tag on is cut, and if another open tag
// still appears, only the text before it is kept (or, if that is empty, the
// text after it).
func cleanLagunaToolCallRaw(raw string) string {
	raw = strings.TrimSpace(raw)
	for strings.HasPrefix(raw, lagunaToolCallOpenTag) {
		raw = strings.TrimSpace(raw[len(lagunaToolCallOpenTag):])
	}

	if body, _, closed := strings.Cut(raw, lagunaToolCallCloseTag); closed {
		raw = strings.TrimSpace(body)
	}

	head, tail, nested := strings.Cut(raw, lagunaToolCallOpenTag)
	if !nested {
		return raw
	}
	if head = strings.TrimSpace(head); head != "" {
		return head
	}
	return strings.TrimSpace(tail)
}
// lagunaToolCallName extracts the tool name from a raw tool call and reports
// whether a non-empty name was found. For a JSON body the "name" field is
// used (invalid JSON yields false). Otherwise the name is the text before
// the first arg_key tag, else before the first "{", else before the first
// line break, else the whole body.
func lagunaToolCallName(raw string) (string, bool) {
	raw = cleanLagunaToolCallRaw(raw)

	if strings.HasPrefix(raw, "{") {
		var parsed struct {
			Name string `json:"name"`
		}
		if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
			return "", false
		}
		name := strings.TrimSpace(parsed.Name)
		return name, name != ""
	}

	// Find where the name ends, trying the delimiters in priority order.
	end := len(raw)
	if i := strings.Index(raw, "<arg_key>"); i >= 0 {
		end = i
	} else if i := strings.Index(raw, "{"); i >= 0 {
		end = i
	} else if i := strings.IndexAny(raw, "\r\n"); i >= 0 {
		end = i
	}

	name := strings.TrimSpace(raw[:end])
	return name, name != ""
}
// consumeTool performs one parsing step while in the tool state. Once the
// tool-call close tag arrives, or the stream ends with a non-blank body, the
// body is parsed, the parser returns to the content state and the call gets
// the next call index. Until then it reports no progress. The buffer and
// state are updated before parsing, so they are already advanced when a parse
// error is returned.
func (p *LagunaParser) consumeTool(done bool) (bool, api.ToolCall, error) {
	pending := p.buffer.String()

	var body, rest string
	switch idx := strings.Index(pending, lagunaToolCallCloseTag); {
	case idx != -1:
		body = pending[:idx]
		rest = strings.TrimLeftFunc(pending[idx+len(lagunaToolCallCloseTag):], unicode.IsSpace)
	case done && strings.TrimSpace(pending) != "":
		// Unterminated call at end of stream: parse what is there.
		body = pending
	default:
		return false, api.ToolCall{}, nil
	}

	p.buffer.Reset()
	p.buffer.WriteString(rest)
	p.state = lagunaParserStateContent

	call, err := parseLagunaToolCall(body, p.tools)
	if err != nil {
		return false, api.ToolCall{}, err
	}
	call.Function.Index = p.callIndex
	p.callIndex++
	return true, call, nil
}
// lagunaArgRE matches one arg_key/arg_value pair of a tag-form tool call.
// Group 1 is the key and group 2 the value; (?s) lets values span lines.
var lagunaArgRE = regexp.MustCompile(`(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>`)
// parseLagunaToolCall parses the body of one tool call.
//
// Three shapes are accepted after cleanup:
//   - a JSON object with "name" and "arguments" fields;
//   - a tool name followed by a JSON object holding the arguments;
//   - a tool name followed by arg_key/arg_value tag pairs, whose values are
//     converted with parseValue using the types declared by the matching tool
//     (when there is one).
//
// The name is replaced by the configured tool it resolves to, if any. An
// error is returned for invalid JSON or an empty tool name.
func parseLagunaToolCall(raw string, tools []api.Tool) (api.ToolCall, error) {
	raw = cleanLagunaToolCallRaw(raw)

	if strings.HasPrefix(raw, "{") {
		var parsed struct {
			Name      string                        `json:"name"`
			Arguments api.ToolCallFunctionArguments `json:"arguments"`
		}
		if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
			return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call: %w", err)
		}
		if parsed.Name == "" {
			return api.ToolCall{}, fmt.Errorf("empty Laguna tool call name")
		}
		if name, ok := lagunaResolveToolName(parsed.Name, tools); ok {
			parsed.Name = name
		}
		return api.ToolCall{
			Function: api.ToolCallFunction{
				Name:      parsed.Name,
				Arguments: parsed.Arguments,
			},
		}, nil
	}

	// Split "name + arguments": arguments start at the first arg_key tag or,
	// failing that, at the first "{".
	name, argsText := raw, ""
	if nameEnd := strings.Index(raw, "<arg_key>"); nameEnd >= 0 {
		name, argsText = raw[:nameEnd], raw[nameEnd:]
	} else if jsonStart := strings.Index(raw, "{"); jsonStart >= 0 {
		name, argsText = raw[:jsonStart], raw[jsonStart:]
	}

	name = strings.TrimSpace(name)
	if name == "" {
		// Same rule as the JSON form above: a call without a name is
		// unusable, so report it instead of emitting it.
		return api.ToolCall{}, fmt.Errorf("empty Laguna tool call name")
	}
	if resolved, ok := lagunaResolveToolName(name, tools); ok {
		name = resolved
	}

	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == name {
			matchedTool = &tools[i]
			break
		}
	}

	call := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      name,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	if jsonArgs := strings.TrimSpace(argsText); strings.HasPrefix(jsonArgs, "{") {
		if err := json.Unmarshal([]byte(jsonArgs), &call.Function.Arguments); err != nil {
			return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call arguments: %w", err)
		}
		return call, nil
	}

	for _, match := range lagunaArgRE.FindAllStringSubmatch(argsText, -1) {
		key := strings.TrimSpace(match[1])
		value := match[2]

		// Collect the declared type(s) of this parameter; an anyOf
		// contributes the types of all its alternatives.
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}
		call.Function.Arguments.Set(key, parseValue(value, paramType))
	}
	return call, nil
}

View File

@@ -0,0 +1,484 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
// lagunaTestTools returns the tool set shared by the Laguna parser tests: a
// single get_weather tool with a string "location" and an integer "days"
// parameter.
func lagunaTestTools() []api.Tool {
	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
	props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
	return []api.Tool{{
		Function: api.ToolFunction{
			Name: "get_weather",
			Parameters: api.ToolFunctionParameters{
				Properties: props,
			},
		},
	}}
}
// TestLagunaParserToolCall checks that the parser is registered as "laguna",
// advertises tool and thinking support, and parses a tag-form tool call with
// arg_key/arg_value pairs into typed arguments.
func TestLagunaParserToolCall(t *testing.T) {
	parser := ParserForName("laguna")
	if parser == nil {
		t.Fatal("expected laguna parser")
	}
	if !parser.HasToolSupport() || !parser.HasThinkingSupport() {
		t.Fatal("laguna parser should advertise tools and thinking")
	}

	parser.Init(lagunaTestTools(), nil, nil)
	content, thinking, calls, err := parser.Add("<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>Paris</arg_value>\n<arg_key>days</arg_key>\n<arg_value>3</arg_value>\n</tool_call>", true)
	if err != nil {
		t.Fatal(err)
	}
	if content != "" || thinking != "" {
		t.Fatalf("content=%q thinking=%q, want empty", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
	}
	if got, _ := calls[0].Function.Arguments.Get("location"); got != "Paris" {
		t.Fatalf("location=%v, want Paris", got)
	}
	// "days" is declared as an integer, so the tag value is expected as an int.
	if got, _ := calls[0].Function.Arguments.Get("days"); got != 3 {
		t.Fatalf("days=%v, want 3", got)
	}
}
// TestLagunaParserJSONToolCall checks the JSON body form of a tool call
// ({"name": ..., "arguments": {...}}) wrapped in <tool_call> tags.
func TestLagunaParserJSONToolCall(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}\n</tool_call>"
	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", fn.Name)
	}
	if got, _ := fn.Arguments.Get("location"); got != "Paris" {
		t.Fatalf("location=%v, want Paris", got)
	}
	// In the JSON form the number is expected as a float64 rather than an int.
	if got, _ := fn.Arguments.Get("days"); got != float64(3) {
		t.Fatalf("days=%v, want 3", got)
	}
}
// TestLagunaParserStandaloneJSONToolCall checks that a bare JSON tool-call
// object with no surrounding tags is still reported as a tool call and does
// not leak into content or thinking.
func TestLagunaParserStandaloneJSONToolCall(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "") {
		t.Fatalf("content=%q thinking=%q", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}
	if name := calls[0].Function.Name; name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", name)
	}
}
// TestLagunaParserStandaloneJSONToolCallAfterLeadingContent checks that prose
// preceding an untagged JSON tool call is returned as content (without the
// separating newline) while the JSON object becomes a tool call.
func TestLagunaParserStandaloneJSONToolCallAfterLeadingContent(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "Let me call the weather tool.\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Let me call the weather tool." && thinking == "") {
		t.Fatalf("content=%q thinking=%q", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}
	if name := calls[0].Function.Name; name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", name)
	}
}
// TestLagunaParserStreamingStandaloneJSONToolCall splits an untagged JSON tool
// call across two chunks, cutting inside a string value. The first chunk must
// emit nothing; the second must complete exactly one call.
func TestLagunaParserStreamingStandaloneJSONToolCall(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const (
		firstChunk  = "{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco,"
		secondChunk = " CA\"}}"
	)

	content, thinking, calls, err := p.Add(firstChunk, false)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "" && len(calls) == 0) {
		t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}

	content, thinking, calls, err = p.Add(secondChunk, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "" && len(calls) == 1) {
		t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", fn.Name)
	}
	if got, _ := fn.Arguments.Get("location"); got != "San Francisco, CA" {
		t.Fatalf("location=%v, want San Francisco, CA", got)
	}
}
// TestLagunaParserNameLineJSONToolCall covers the hybrid form where the tool
// name sits on the first line after <tool_call> and the arguments follow as a
// plain JSON object.
func TestLagunaParserNameLineJSONToolCall(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "<tool_call>get_weather\n{\"location\":\"San Francisco\"}</tool_call>"
	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", fn.Name)
	}
	if got, _ := fn.Arguments.Get("location"); got != "San Francisco" {
		t.Fatalf("location=%v, want San Francisco", got)
	}
}
// TestLagunaParserNormalizesCommonToolAliases registers a tool named "read"
// and checks that a call emitted under the name "read_file" is reported with
// the registered name.
func TestLagunaParserNormalizesCommonToolAliases(t *testing.T) {
	params := api.NewToolPropertiesMap()
	params.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
	readTool := api.Tool{
		Function: api.ToolFunction{
			Name:       "read",
			Parameters: api.ToolFunctionParameters{Properties: params},
		},
	}

	p := ParserForName("laguna")
	p.Init([]api.Tool{readTool}, nil, nil)

	const input = "<tool_call>\n{\"name\":\"read_file\",\"arguments\":{\"path\":\"./go.mod\"}}\n</tool_call>"
	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "read" {
		t.Fatalf("name=%q, want read", fn.Name)
	}
	if got, _ := fn.Arguments.Get("path"); got != "./go.mod" {
		t.Fatalf("path=%v, want ./go.mod", got)
	}
}
// TestLagunaParserIgnoresDuplicatedNestedToolCall feeds output where the same
// tool call is opened a second time before the first is closed, and expects a
// single call to be reported.
func TestLagunaParserIgnoresDuplicatedNestedToolCall(t *testing.T) {
	params := api.NewToolPropertiesMap()
	params.Set("name", api.ToolProperty{Type: api.PropertyType{"string"}})
	skillTool := api.Tool{
		Function: api.ToolFunction{
			Name:       "skill",
			Parameters: api.ToolFunctionParameters{Properties: params},
		},
	}

	p := ParserForName("laguna")
	p.Init([]api.Tool{skillTool}, nil, nil)

	const input = "<tool_call>skill\n{\"name\":\"git-diff-review\"}\n<tool_call>skill\n{\"name\":\"git-diff-review\"}</tool_call>"
	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "skill" {
		t.Fatalf("name=%q, want skill", fn.Name)
	}
	if got, _ := fn.Arguments.Get("name"); got != "git-diff-review" {
		t.Fatalf("name arg=%v, want git-diff-review", got)
	}
}
// TestLagunaParserThinkingThenTool checks, with thinking explicitly enabled,
// that a <think> block followed by a tool call is split into thinking text and
// one tool call with no content.
func TestLagunaParserThinkingThenTool(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true})

	const input = "<think>Need current weather.</think>\n<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>SF</arg_value>\n</tool_call>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if content != "" {
		t.Fatalf("content=%q, want empty", content)
	}
	if thinking != "Need current weather." {
		t.Fatalf("thinking=%q, want reasoning", thinking)
	}
	if !(len(calls) == 1 && calls[0].Function.Name == "get_weather") {
		t.Fatalf("unexpected calls: %#v", calls)
	}
}
// TestLagunaParserUserTaggedToolAlias checks that a tool call wrapped in
// <user>...</user> instead of <tool_call> tags is still recognised when the
// first line names a known tool.
func TestLagunaParserUserTaggedToolAlias(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "<user>get_weather\n<arg_key>location</arg_key>\n<arg_value>San Francisco, CA</arg_value>\n</user>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "") {
		t.Fatalf("content=%q thinking=%q, want empty", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "get_weather" {
		t.Fatalf("name=%q, want get_weather", fn.Name)
	}
	if got, _ := fn.Arguments.Get("location"); got != "San Francisco, CA" {
		t.Fatalf("location=%v, want San Francisco, CA", got)
	}
}
// TestLagunaParserUserTaggedToolAliasWithLeadingContent checks that prose
// before a <user>-wrapped tool call is returned as content and the wrapped
// block is returned as a call to the "read" tool.
func TestLagunaParserUserTaggedToolAliasWithLeadingContent(t *testing.T) {
	p := ParserForName("laguna")

	params := api.NewToolPropertiesMap()
	params.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
	readTool := api.Tool{
		Function: api.ToolFunction{
			Name:       "read",
			Parameters: api.ToolFunctionParameters{Properties: params},
		},
	}
	p.Init([]api.Tool{readTool}, nil, nil)

	const input = "I'll read the file for you.\n<user>read\n<arg_key>path</arg_key>\n<arg_value>/Users/test/code/myproject/go.mod</arg_value>\n</user>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "I'll read the file for you." && thinking == "") {
		t.Fatalf("content=%q thinking=%q", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "read" {
		t.Fatalf("name=%q, want read", fn.Name)
	}
	if got, _ := fn.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" {
		t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got)
	}
}
// TestLagunaParserUserTaggedJSONToolCallWithLeadingContent checks the
// <user>-wrapped variant whose body is a JSON tool-call object, directly
// following prose with no newline in between.
func TestLagunaParserUserTaggedJSONToolCallWithLeadingContent(t *testing.T) {
	p := ParserForName("laguna")

	params := api.NewToolPropertiesMap()
	params.Set("command", api.ToolProperty{Type: api.PropertyType{"string"}})
	bashTool := api.Tool{
		Function: api.ToolFunction{
			Name:       "bash",
			Parameters: api.ToolFunctionParameters{Properties: params},
		},
	}
	p.Init([]api.Tool{bashTool}, nil, nil)

	const input = "I'll run git diff for you.<user>\n{\"name\":\"bash\",\"arguments\":{\"command\":\"git diff main\"}}\n</user>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "I'll run git diff for you." && thinking == "") {
		t.Fatalf("content=%q thinking=%q", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "bash" {
		t.Fatalf("name=%q, want bash", fn.Name)
	}
	if got, _ := fn.Arguments.Get("command"); got != "git diff main" {
		t.Fatalf("command=%v, want git diff main", got)
	}
}
// TestLagunaParserStreamingUserTaggedToolAliasAfterContent splits the stream
// in the middle of the "<user>" tag. The first chunk must release only the
// prose before the partial tag; the second must produce the tool call and no
// further content.
func TestLagunaParserStreamingUserTaggedToolAliasAfterContent(t *testing.T) {
	p := ParserForName("laguna")

	params := api.NewToolPropertiesMap()
	params.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
	readTool := api.Tool{
		Function: api.ToolFunction{
			Name:       "read",
			Parameters: api.ToolFunctionParameters{Properties: params},
		},
	}
	p.Init([]api.Tool{readTool}, nil, nil)

	const (
		firstChunk  = "I'll read the file for you.<us"
		secondChunk = "er>read\n<arg_key>path</arg_key>\n<arg_value>/Users/test/code/myproject/go.mod</arg_value>\n</user>"
	)

	content, thinking, calls, err := p.Add(firstChunk, false)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "I'll read the file for you." && thinking == "" && len(calls) == 0) {
		t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}

	content, thinking, calls, err = p.Add(secondChunk, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "") {
		t.Fatalf("second chunk content=%q thinking=%q", content, thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("calls=%d, want 1", len(calls))
	}

	fn := &calls[0].Function
	if fn.Name != "read" {
		t.Fatalf("name=%q, want read", fn.Name)
	}
	if got, _ := fn.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" {
		t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got)
	}
}
// TestLagunaParserUserTaggedNonToolContent checks that a <user> block whose
// body does not name a known tool is passed through verbatim as content.
func TestLagunaParserUserTaggedNonToolContent(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "<user>hello</user>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "<user>hello</user>" && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingDefaultsOn checks that with a nil think value the
// parser still separates a <think> block from the answer that follows it.
func TestLagunaParserThinkingDefaultsOn(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, nil)

	const input = "<think>Need to reason.</think>\nDirect answer."
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Direct answer." && thinking == "Need to reason." && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingDefaultsOnWhenToolsPresent checks that supplying
// tools does not turn off the default thinking behaviour: the <think> block is
// reported as thinking, the tool call is parsed, and nothing reaches content.
func TestLagunaParserThinkingDefaultsOnWhenToolsPresent(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, nil)

	const input = "<think>Need to reason.</think>\n<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>Paris</arg_value>\n</tool_call>"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(thinking == "Need to reason." && len(calls) == 1) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
	if content != "" {
		t.Fatalf("content=%q, want thinking block suppressed from content when default thinking is enabled", content)
	}
}
// TestLagunaParserThinkingExplicitlyDisabled checks that with thinking turned
// off a <think> block the model emits anyway is dropped: it shows up neither
// as thinking nor as content.
func TestLagunaParserThinkingExplicitlyDisabled(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, &api.ThinkValue{Value: false})

	const input = "<think>Hidden?</think>\nDirect answer."
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Direct answer." && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingExplicitlyDisabledDropsLeadingCloseTag checks that,
// with thinking disabled, a stray leading </think> and the newline after it
// are removed from the content.
func TestLagunaParserThinkingExplicitlyDisabledDropsLeadingCloseTag(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, &api.ThinkValue{Value: false})

	const input = "</think>\nTokyo\n"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Tokyo\n" && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingEnabledDropsLeadingCloseTag checks that, with
// thinking explicitly enabled, a stray leading </think> produces no thinking
// text and is removed from the content.
func TestLagunaParserThinkingEnabledDropsLeadingCloseTag(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, &api.ThinkValue{Value: true})

	const input = "</think>\nTokyo\n"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Tokyo\n" && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingDefaultOnDropsLeadingCloseTag repeats the stray
// leading </think> check with a nil think value (the default configuration).
func TestLagunaParserThinkingDefaultOnDropsLeadingCloseTag(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, nil)

	const input = "</think>\nTokyo\n"
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Tokyo\n" && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserThinkingEnabledUntaggedAnswerIsContent checks that output
// with no <think> tags at all is treated as content even when thinking is
// enabled.
func TestLagunaParserThinkingEnabledUntaggedAnswerIsContent(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(nil, nil, &api.ThinkValue{Value: true})

	const input = "Direct answer."
	content, thinking, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "Direct answer." && thinking == "" && len(calls) == 0) {
		t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}
// TestLagunaParserSplitToolTag splits the stream in the middle of the
// "<tool_call>" tag while still inside an unclosed <think> block. The first
// chunk must emit only the thinking text before the partial tag; the second
// must yield the tool call and nothing else.
func TestLagunaParserSplitToolTag(t *testing.T) {
	p := ParserForName("laguna")
	p.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true})

	const (
		firstChunk  = "<think>Need lookup<tool_c"
		secondChunk = "all>get_weather\n<arg_key>location</arg_key>\n<arg_value>SF</arg_value>\n</tool_call>"
	)

	content, thinking, calls, err := p.Add(firstChunk, false)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "Need lookup" && len(calls) == 0) {
		t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}

	content, thinking, calls, err = p.Add(secondChunk, true)
	if err != nil {
		t.Fatal(err)
	}
	if !(content == "" && thinking == "" && len(calls) == 1) {
		t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}
}

View File

@@ -16,14 +16,17 @@ const (
)
// Tag literals the Nemotron parser looks for in model output.
const (
// nemotronThinkOpen is the optional opening tag stripped from the start of the thinking section.
nemotronThinkOpen = "<think>"
// nemotronThinkClose marks the end of the thinking section.
nemotronThinkClose = "</think>"
// nemotronToolCallOpen starts a tool call; the model may emit it without first closing the thinking section.
nemotronToolCallOpen = "<tool_call>"
)
type Nemotron3NanoParser struct {
state Nemotron3NanoParserState
buffer strings.Builder
toolParser *Qwen3CoderParser
state Nemotron3NanoParserState
buffer strings.Builder
toolParser *Qwen3CoderParser
maybeThinkingOpenAtBOL bool
skipThinkingLeadingWS bool
}
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
@@ -32,14 +35,18 @@ func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.toolParser = &Qwen3CoderParser{}
p.toolParser.Init(tools, nil, nil)
p.buffer.Reset()
p.maybeThinkingOpenAtBOL = false
p.skipThinkingLeadingWS = false
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
thinkingEnabled := thinkValue == nil || thinkValue.Bool()
prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !thinkingEnabled || (prefill && lastMessage.Content != "") {
p.state = Nemotron3NanoCollectingContent
} else {
p.state = Nemotron3NanoCollectingThinking
p.maybeThinkingOpenAtBOL = true
}
return tools
@@ -61,6 +68,29 @@ func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking
// Nemotron3NanoCollectingThinking - buffer and look for end markers
p.buffer.WriteString(s)
if p.skipThinkingLeadingWS {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(trimmed)
if trimmed == "" {
return "", "", nil, nil
}
p.skipThinkingLeadingWS = false
}
if p.stripOpeningThinkTag() {
return p.Add("", done)
}
if p.maybeThinkingOpenAtBOL {
bufStr := p.buffer.String()
trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace)
if trimmed == "" || overlap(trimmed, nemotronThinkOpen) == len(trimmed) {
if len(trimmed) != len(bufStr) {
p.buffer.Reset()
p.buffer.WriteString(trimmed)
}
return "", "", nil, nil
}
}
bufStr := p.buffer.String()
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
@@ -124,3 +154,35 @@ func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
p.buffer.Reset()
return bufStr
}
// stripOpeningThinkTag removes a leading "<think>" tag (and the whitespace
// around it) from the buffered thinking text. It only acts while
// maybeThinkingOpenAtBOL is set, i.e. before any non-tag thinking text has
// been seen.
//
// It returns true only when a complete opening tag was removed; in that case
// skipThinkingLeadingWS is set so whitespace arriving in later chunks is also
// dropped. While the buffer holds nothing but whitespace or a partial prefix
// of the tag, the flag stays set so the decision is retried on the next chunk.
// Any other text clears the flag and leaves the buffer untouched.
func (p *Nemotron3NanoParser) stripOpeningThinkTag() bool {
	if !p.maybeThinkingOpenAtBOL {
		return false
	}

	raw := p.buffer.String()
	rest := strings.TrimLeftFunc(raw, unicode.IsSpace)

	switch {
	case rest == "":
		// Only whitespace so far: discard it and keep waiting for the tag.
		p.buffer.Reset()
		return false

	case strings.HasPrefix(rest, nemotronThinkOpen):
		afterTag := strings.TrimLeftFunc(rest[len(nemotronThinkOpen):], unicode.IsSpace)
		p.buffer.Reset()
		p.buffer.WriteString(afterTag)
		p.maybeThinkingOpenAtBOL = false
		p.skipThinkingLeadingWS = true
		return true

	case overlap(rest, nemotronThinkOpen) == len(rest):
		// The whole remainder could still grow into the tag; keep it buffered,
		// minus any leading whitespace.
		if len(rest) != len(raw) {
			p.buffer.Reset()
			p.buffer.WriteString(rest)
		}
		return false

	default:
		// Ordinary thinking text: stop looking for an opening tag.
		p.maybeThinkingOpenAtBOL = false
		return false
	}
}

View File

@@ -82,6 +82,20 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking: "My thoughts...",
expectedContent: "Content here.",
},
{
name: "leading open think tag is ignored",
input: "<think>\nLet me think about this...</think>\nHere is my answer.",
thinkValue: &api.ThinkValue{Value: true},
expectedThinking: "Let me think about this...",
expectedContent: "Here is my answer.",
},
{
name: "empty explicit think block is ignored",
input: "<think></think>\nHere is my answer.",
thinkValue: &api.ThinkValue{Value: true},
expectedThinking: "",
expectedContent: "Here is my answer.",
},
}
for _, tt := range tests {
@@ -191,6 +205,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
},
},
},
{
name: "leading open think tag split across chunks",
chunks: []string{"<th", "ink>", "\nThink first", "</think>", "\nDone."},
thinkValue: &api.ThinkValue{Value: true},
expectedThinking: "Think first",
expectedContent: "Done.",
},
}
for _, tt := range tests {
@@ -265,11 +286,11 @@ func TestNemotron3NanoParser_Init(t *testing.T) {
}
})
t.Run("starts in content state when nil thinkValue", func(t *testing.T) {
t.Run("starts in thinking state when nil thinkValue", func(t *testing.T) {
p := &Nemotron3NanoParser{}
p.Init(nil, nil, nil)
if p.state != Nemotron3NanoCollectingContent {
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
if p.state != Nemotron3NanoCollectingThinking {
t.Errorf("expected state Nemotron3NanoCollectingThinking, got %v", p.state)
}
})
@@ -281,6 +302,29 @@ func TestNemotron3NanoParser_Init(t *testing.T) {
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
}
})
t.Run("reinit clears buffered state", func(t *testing.T) {
p := &Nemotron3NanoParser{}
p.Init(nil, nil, &api.ThinkValue{Value: true})
if _, _, _, err := p.Add("thinking in progress", false); err != nil {
t.Fatalf("unexpected error: %v", err)
}
p.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := p.Add("content only", true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "content only" {
t.Fatalf("expected content after reinit, got %q", content)
}
if thinking != "" {
t.Fatalf("expected no thinking after reinit, got %q", thinking)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls after reinit, got %v", calls)
}
})
}
func TestNemotron3NanoParser_WithTools(t *testing.T) {

View File

@@ -87,6 +87,8 @@ func ParserForName(name string) Parser {
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
case "laguna":
return &LagunaParser{}
default:
return nil
}

111
model/renderers/laguna.go Normal file
View File

@@ -0,0 +1,111 @@
package renderers
import (
"strings"
"github.com/ollama/ollama/api"
)
// Token and tag literals written by LagunaRenderer.
const (
// lagunaBOS is written once at the very start of every rendered prompt.
// NOTE(review): the literal spells "EOS" although the name says BOS — presumably the
// model uses this token as its sequence-start marker; confirm against the tokenizer config.
lagunaBOS = "〈|EOS|〉"
// lagunaThoughtOpen and lagunaThoughtClose wrap an assistant message's Thinking text.
lagunaThoughtOpen = "<think>"
lagunaThoughtClose = "</think>"
)
// LagunaRenderer renders chat messages and tool definitions into the Laguna
// prompt format.
type LagunaRenderer struct{}

// Render builds the full prompt string.
//
// The prompt always starts with lagunaBOS and a <system> header holding a
// reasoning directive (thinking is on unless think is non-nil and false), the
// first message's content when that message has role "system", and the tool
// catalogue when tools are supplied. The remaining messages follow in order,
// each wrapped in a role tag. A final assistant message that carries content,
// thinking or tool calls is left unclosed so the model continues it; when the
// last message is not from the assistant, an opening <assistant> tag is
// appended as the generation prompt.
//
// The returned error is always nil in this implementation.
func (r *LagunaRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	var out strings.Builder
	out.WriteString(lagunaBOS)

	reasoning := think == nil || think.Bool()
	leadingSystem := len(messages) > 0 && messages[0].Role == "system"

	// System header: directive, optional user-supplied system text, optional tools.
	out.WriteString("<system>\n")
	directive := "You should respond directly without using chain-of-thought reasoning tags."
	if reasoning {
		directive = "You should use chain-of-thought reasoning. Put your reasoning inside <think> </think> tags before your response."
	}
	out.WriteString(directive)

	if leadingSystem {
		// Trailing newlines are dropped; a whitespace-only system message is omitted.
		if header := strings.TrimRight(messages[0].Content, "\n"); strings.TrimSpace(header) != "" {
			out.WriteByte('\n')
			out.WriteString(header)
		}
	}

	if len(tools) > 0 {
		out.WriteString("\n\n### Tools\n\n" +
			"You may call functions to assist with the user query.\n" +
			"All available function signatures are listed below:\n" +
			"<available_tools>\n")
		for _, tool := range tools {
			// A tool that fails to marshal is skipped rather than failing the render.
			encoded, err := marshalWithSpaces(tool)
			if err != nil {
				continue
			}
			out.Write(encoded)
			out.WriteByte('\n')
		}
		out.WriteString("</available_tools>\n\n" +
			"For each function call, return a json object with function name and arguments within '<tool_call>' and '</tool_call>' tags:\n" +
			"<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>")
	}
	out.WriteString("\n</system>\n")

	lastIdx := len(messages) - 1
	for idx, msg := range messages {
		// The leading system message was already folded into the header.
		if idx == 0 && leadingSystem {
			continue
		}

		switch msg.Role {
		case "user":
			out.WriteString("<user>\n")
			out.WriteString(msg.Content)
			out.WriteString("\n</user>\n")

		case "assistant":
			// A non-empty final assistant message is a prefill: leave the turn open.
			continuing := idx == lastIdx &&
				(msg.Content != "" || msg.Thinking != "" || len(msg.ToolCalls) > 0)

			out.WriteString("<assistant>\n")
			if reasoning && msg.Thinking != "" {
				out.WriteString(lagunaThoughtOpen)
				out.WriteString(msg.Thinking)
				out.WriteString(lagunaThoughtClose)
				out.WriteByte('\n')
			}
			if body := strings.Trim(msg.Content, "\n"); body != "" {
				out.WriteString(body)
				out.WriteByte('\n')
			}
			for _, call := range msg.ToolCalls {
				out.WriteString("<tool_call>")
				out.WriteString(call.Function.Name)
				out.WriteByte('\n')
				for key, val := range call.Function.Arguments.All() {
					out.WriteString("<arg_key>")
					out.WriteString(key)
					out.WriteString("</arg_key>\n")
					out.WriteString("<arg_value>")
					out.WriteString(formatToolCallArgument(val))
					out.WriteString("</arg_value>\n")
				}
				out.WriteString("</tool_call>\n")
			}
			if !continuing {
				out.WriteString("</assistant>\n")
			}

		case "tool":
			out.WriteString("<tool_response>\n")
			out.WriteString(msg.Content)
			out.WriteString("\n</tool_response>\n")

		case "system":
			out.WriteString("<system>\n")
			out.WriteString(msg.Content)
			out.WriteString("\n</system>\n")
		}
	}

	// Generation prompt, unless the conversation already ends on an assistant turn.
	if len(messages) == 0 || messages[lastIdx].Role != "assistant" {
		out.WriteString("<assistant>\n")
	}
	return out.String(), nil
}

View File

@@ -0,0 +1,339 @@
package renderers
import (
"encoding/json"
"os"
"os/exec"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// Reasoning directives expected at the top of the rendered <system> header.
const (
// lagunaDirectDirective is expected when thinking is explicitly disabled.
lagunaDirectDirective = "You should respond directly without using chain-of-thought reasoning tags."
// lagunaThinkDirective is expected when thinking is enabled or left unspecified.
lagunaThinkDirective = "You should use chain-of-thought reasoning. Put your reasoning inside <think> </think> tags before your response."
)
// TestLagunaRendererReferenceFlowCoverage pins the exact prompt text produced
// by LagunaRenderer for each control-flow path: thinking on/off/default, a
// leading and a mid-conversation system message, tool catalogues, assistant
// history with thinking and tool calls, and a final assistant prefill.
func TestLagunaRendererReferenceFlowCoverage(t *testing.T) {
	// Prompt fragments shared by several expected outputs.
	const (
		systemOpen   = "〈|EOS|〉<system>\n"
		systemClose  = "\n</system>\n"
		generation   = "<assistant>\n"
		toolsSection = "\n\n### Tools\n\n" +
			"You may call functions to assist with the user query.\n" +
			"All available function signatures are listed below:\n" +
			"<available_tools>\n" +
			`{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "City"}}}}}` + "\n" +
			"</available_tools>\n\n" +
			"For each function call, return a json object with function name and arguments within '<tool_call>' and '</tool_call>' tags:\n" +
			"<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
	)

	weatherTools := lagunaWeatherTool()

	cases := []struct {
		name     string
		messages []api.Message
		tools    []api.Tool
		think    *api.ThinkValue
		want     string
	}{
		{
			name:     "user_only_thinking_default_on",
			messages: []api.Message{{Role: "user", Content: "Hello"}},
			want: systemOpen + lagunaThinkDirective + systemClose +
				"<user>\nHello\n</user>\n" +
				generation,
		},
		{
			name:     "user_only_thinking_enabled",
			messages: []api.Message{{Role: "user", Content: "Hello"}},
			think:    &api.ThinkValue{Value: true},
			want: systemOpen + lagunaThinkDirective + systemClose +
				"<user>\nHello\n</user>\n" +
				generation,
		},
		{
			name:     "user_only_thinking_disabled",
			messages: []api.Message{{Role: "user", Content: "Hello"}},
			think:    &api.ThinkValue{Value: false},
			want: systemOpen + lagunaDirectDirective + systemClose +
				"<user>\nHello\n</user>\n" +
				generation,
		},
		{
			name: "first_system_is_header",
			messages: []api.Message{
				{Role: "system", Content: "Stay concise.\n\n"},
				{Role: "user", Content: "Hi"},
			},
			want: systemOpen + lagunaThinkDirective + "\nStay concise." + systemClose +
				"<user>\nHi\n</user>\n" +
				generation,
		},
		{
			name: "additional_system_message_renders_in_loop",
			messages: []api.Message{
				{Role: "system", Content: "Primary."},
				{Role: "user", Content: "Hi"},
				{Role: "system", Content: "Secondary."},
			},
			want: systemOpen + lagunaThinkDirective + "\nPrimary." + systemClose +
				"<user>\nHi\n</user>\n" +
				"<system>\nSecondary.\n</system>\n" +
				generation,
		},
		{
			name: "tools_in_header",
			messages: []api.Message{
				{Role: "system", Content: "Stay concise."},
				{Role: "user", Content: "Weather?"},
			},
			tools: weatherTools,
			think: &api.ThinkValue{Value: true},
			want: systemOpen + lagunaThinkDirective + "\nStay concise." + toolsSection + systemClose +
				"<user>\nWeather?\n</user>\n" +
				generation,
		},
		{
			name: "tools_default_thinking_on_when_unspecified",
			messages: []api.Message{
				{Role: "user", Content: "Weather?"},
			},
			tools: weatherTools,
			want: systemOpen + lagunaThinkDirective + toolsSection + systemClose +
				"<user>\nWeather?\n</user>\n" +
				generation,
		},
		{
			name: "assistant_history_with_thinking_content_tool_and_response",
			messages: []api.Message{
				{Role: "user", Content: "Add these."},
				{
					Role:     "assistant",
					Content:  "\nCalling the tool.\n",
					Thinking: "Need addition.",
					ToolCalls: []api.ToolCall{{
						Function: api.ToolCallFunction{
							Name: "add",
							Arguments: testArgsOrdered([]orderedArg{
								{Key: "a", Value: 2},
								{Key: "b", Value: 3},
							}),
						},
					}},
				},
				{Role: "tool", Content: "5"},
				{Role: "user", Content: "Thanks"},
			},
			think: &api.ThinkValue{Value: true},
			want: systemOpen + lagunaThinkDirective + systemClose +
				"<user>\nAdd these.\n</user>\n" +
				"<assistant>\n" +
				"<think>Need addition.</think>\n" +
				"Calling the tool.\n" +
				"<tool_call>add\n" +
				"<arg_key>a</arg_key>\n<arg_value>2</arg_value>\n" +
				"<arg_key>b</arg_key>\n<arg_value>3</arg_value>\n" +
				"</tool_call>\n" +
				"</assistant>\n" +
				"<tool_response>\n5\n</tool_response>\n" +
				"<user>\nThanks\n</user>\n" +
				generation,
		},
		{
			name: "final_assistant_prefill_is_continued",
			messages: []api.Message{
				{Role: "user", Content: "Complete this"},
				{Role: "assistant", Content: "Partial"},
			},
			// The assistant turn is left open and no extra generation prompt is added.
			want: systemOpen + lagunaThinkDirective + systemClose +
				"<user>\nComplete this\n</user>\n" +
				"<assistant>\nPartial\n",
		},
	}

	r := &LagunaRenderer{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rendered, err := r.Render(tc.messages, tc.tools, tc.think)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(tc.want, rendered); diff != "" {
				t.Fatalf("renderer output mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestLagunaRendererMatchesLocalJinjaControlFlow is an opt-in check that
// compares LagunaRenderer output against the chat template shipped with one or
// more local Laguna model directories, rendered through Python.
//
// It is configured entirely through the environment so it can run on any
// machine rather than depending on one developer's directory layout:
//
//   - VERIFY_LAGUNA_JINJA2: any non-empty value enables the test; otherwise it is skipped.
//   - VERIFY_LAGUNA_PYTHON: Python interpreter to run (default: python3 found on PATH).
//     The interpreter must be able to import transformers.
//   - VERIFY_LAGUNA_MODEL_DIRS: list of model directories to compare against,
//     separated by os.PathListSeparator. At least one entry is required.
func TestLagunaRendererMatchesLocalJinjaControlFlow(t *testing.T) {
	if os.Getenv("VERIFY_LAGUNA_JINJA2") == "" {
		t.Skip("set VERIFY_LAGUNA_JINJA2=1 and VERIFY_LAGUNA_MODEL_DIRS to compare against a local Laguna chat_template.jinja")
	}

	// Resolve the interpreter: explicit override first, then PATH.
	python := os.Getenv("VERIFY_LAGUNA_PYTHON")
	if python == "" {
		found, err := exec.LookPath("python3")
		if err != nil {
			t.Fatalf("VERIFY_LAGUNA_JINJA2 requires python3 on PATH or VERIFY_LAGUNA_PYTHON to be set: %v", err)
		}
		python = found
	} else if _, err := os.Stat(python); err != nil {
		t.Fatalf("VERIFY_LAGUNA_PYTHON=%s is not usable: %v", python, err)
	}

	// Collect the model directories, ignoring empty list entries.
	var modelDirs []string
	for _, dir := range strings.Split(os.Getenv("VERIFY_LAGUNA_MODEL_DIRS"), string(os.PathListSeparator)) {
		if dir = strings.TrimSpace(dir); dir != "" {
			modelDirs = append(modelDirs, dir)
		}
	}
	if len(modelDirs) == 0 {
		t.Fatal("VERIFY_LAGUNA_JINJA2 requires VERIFY_LAGUNA_MODEL_DIRS to list at least one Laguna model directory")
	}

	tests := []struct {
		name     string
		messages []api.Message
		think    *api.ThinkValue
	}{
		{
			name:     "user_only",
			messages: []api.Message{{Role: "user", Content: "Hello"}},
		},
		{
			name: "system_user",
			messages: []api.Message{
				{Role: "system", Content: "Stay concise.\n"},
				{Role: "user", Content: "Hello"},
			},
		},
		{
			name: "additional_system_and_tool_response",
			messages: []api.Message{
				{Role: "system", Content: "Primary."},
				{Role: "user", Content: "Weather?"},
				{Role: "assistant", Content: "Calling."},
				{Role: "tool", Content: "Sunny"},
				{Role: "system", Content: "Secondary."},
			},
		},
		{
			name:     "thinking_enabled",
			messages: []api.Message{{Role: "user", Content: "Think briefly."}},
			think:    &api.ThinkValue{Value: true},
		},
		{
			name:     "thinking_disabled",
			messages: []api.Message{{Role: "user", Content: "Answer directly."}},
			think:    &api.ThinkValue{Value: false},
		},
	}

	renderer := &LagunaRenderer{}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := renderer.Render(tt.messages, nil, tt.think)
			if err != nil {
				t.Fatal(err)
			}
			for _, modelDir := range modelDirs {
				want := renderLagunaChatTemplate(t, python, modelDir, tt.messages, tt.think)
				if diff := cmp.Diff(want, got); diff != "" {
					t.Fatalf("%s mismatch (-chat_template +renderer):\n%s", modelDir, diff)
				}
			}
		})
	}
}
// renderLagunaChatTemplate renders messages through the chat template of the
// model stored in modelDir by running an inline Python script with the given
// interpreter, and returns the script's standard output.
//
// Only each message's Role and Content are passed to the template; thinking
// text and tool calls are not forwarded. Thinking is requested unless think is
// non-nil and false. Any failure to marshal the messages or run the script
// fails the test immediately.
func renderLagunaChatTemplate(t *testing.T, python, modelDir string, messages []api.Message, think *api.ThinkValue) string {
t.Helper()
// templateMessage is the reduced message shape handed to the Python side.
type templateMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
templateMessages := make([]templateMessage, 0, len(messages))
for _, msg := range messages {
templateMessages = append(templateMessages, templateMessage{
Role: msg.Role,
Content: msg.Content,
})
}
messagesJSON, err := json.Marshal(templateMessages)
if err != nil {
t.Fatalf("failed to marshal messages: %v", err)
}
// Passed as the literal strings "True"/"False" and compared on the Python side.
enableThinking := "True"
if think != nil && !think.Bool() {
enableThinking = "False"
}
// The script receives the model directory, the messages JSON and the thinking
// flag as argv[1..3], and prints the rendered prompt without a trailing newline.
script := `
import json
import sys
from transformers import AutoTokenizer
model_dir = sys.argv[1]
messages = json.loads(sys.argv[2])
enable_thinking = sys.argv[3] == "True"
tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
print(tok.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
), end="")
`
cmd := exec.Command(python, "-c", script, modelDir, string(messagesJSON), enableThinking)
// stdout carries the rendered prompt; stderr is kept only for the failure message.
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
t.Fatalf("chat_template render failed: %v\nstderr: %s", err, stderr.String())
}
return stdout.String()
}
// lagunaWeatherTool returns the one-element tool list used by the renderer
// tests: a get_weather function with a single required string parameter,
// "location", described as "City".
func lagunaWeatherTool() []api.Tool {
	location := orderedProp{
		Key: "location",
		Value: api.ToolProperty{
			Type:        api.PropertyType{"string"},
			Description: "City",
		},
	}

	params := api.ToolFunctionParameters{
		Type:       "object",
		Required:   []string{"location"},
		Properties: testPropsOrdered([]orderedProp{location}),
	}

	weather := api.Tool{
		Type: "function",
		Function: api.ToolFunction{
			Name:        "get_weather",
			Description: "Get weather",
			Parameters:  params,
		},
	}
	return []api.Tool{weather}
}

View File

@@ -3,6 +3,9 @@ package renderers
import (
"encoding/json"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/api"
@@ -12,15 +15,15 @@ type Nemotron3NanoRenderer struct{}
func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
imageOffset := 0
// thinking is enabled if user requests it
enableThinking := thinkValue != nil && thinkValue.Bool()
enableThinking := r.resolveThinking(messages, thinkValue)
// Extract system message if present
var systemMessage string
var loopMessages []api.Message
if len(messages) > 0 && messages[0].Role == "system" {
systemMessage = messages[0].Content
systemMessage = r.sanitizeSystemMessage(messages[0].Content)
loopMessages = messages[1:]
} else {
loopMessages = messages
@@ -34,6 +37,7 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
}
}
sb.WriteString("\n\n\n")
sb.WriteString("<|im_start|>system\n")
if systemMessage != "" {
sb.WriteString(systemMessage)
@@ -45,28 +49,30 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
}
sb.WriteString(r.renderTools(tools))
}
sb.WriteString("<|im_end|>\n")
sb.WriteString("<|im_end|>\n\n")
for i, message := range loopMessages {
switch message.Role {
case "assistant":
// Build content with thinking tags
content := r.buildContent(message)
shouldTruncate := i < lastUserIdx
if len(message.ToolCalls) > 0 {
sb.WriteString("<|im_start|>assistant\n")
sb.WriteString(r.formatContent(content, shouldTruncate, true))
sb.WriteString(r.formatToolCallContent(content, shouldTruncate))
r.writeToolCalls(&sb, message.ToolCalls)
sb.WriteString("<|im_end|>\n")
} else {
formatted := r.formatContent(content, shouldTruncate, false)
sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n")
formatted := r.formatAssistantContent(content, shouldTruncate)
sb.WriteString("<|im_start|>assistant\n")
sb.WriteString(formatted)
sb.WriteString("<|im_end|>\n")
}
case "user", "system":
sb.WriteString("<|im_start|>" + message.Role + "\n")
sb.WriteString(message.Content)
sb.WriteString(r.renderMessageContent(message, imageOffset))
imageOffset += len(message.Images)
sb.WriteString("<|im_end|>\n")
case "tool":
@@ -90,6 +96,8 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
}
}
sb.WriteString("\n")
// Add generation prompt
if enableThinking {
sb.WriteString("<|im_start|>assistant\n<think>\n")
@@ -119,7 +127,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
sb.WriteString("\n<name>" + paramName + "</name>")
if len(paramFields.Type) > 0 {
sb.WriteString("\n<type>" + strings.Join(paramFields.Type, ", ") + "</type>")
sb.WriteString("\n<type>" + r.formatPropertyType(paramFields.Type) + "</type>")
}
if paramFields.Description != "" {
@@ -127,17 +135,17 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
}
if len(paramFields.Enum) > 0 {
enumJSON, _ := json.Marshal(paramFields.Enum)
sb.WriteString("\n<enum>" + string(enumJSON) + "</enum>")
sb.WriteString("\n<enum>" + r.pythonJSON(paramFields.Enum) + "</enum>")
}
r.renderToolPropertyExtraKeys(&sb, paramFields)
sb.WriteString("\n</parameter>")
}
}
r.renderToolParameterExtraKeys(&sb, fn.Parameters)
if len(fn.Parameters.Required) > 0 {
reqJSON, _ := json.Marshal(fn.Parameters.Required)
sb.WriteString("\n<required>" + string(reqJSON) + "</required>")
sb.WriteString("\n<required>" + r.pythonJSON(fn.Parameters.Required) + "</required>")
}
sb.WriteString("\n</parameters>")
@@ -159,27 +167,38 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
}
func (r *Nemotron3NanoRenderer) buildContent(message api.Message) string {
// The parser always extracts thinking into the Thinking field,
// so Content will never have <think> tags embedded
content := nemotron3NanoRenderContent(message.Content)
if message.Thinking != "" {
return "<think>\n" + message.Thinking + "\n</think>\n" + message.Content
return "<think>\n" + message.Thinking + "\n</think>\n" + content
}
return "<think></think>" + message.Content
if !strings.Contains(content, "<think>") && !strings.Contains(content, "</think>") {
return "<think></think>" + content
}
return content
}
func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, addNewline bool) string {
if content == "" {
func (r *Nemotron3NanoRenderer) formatAssistantContent(content string, truncate bool) string {
if !truncate {
return strings.TrimSpace(content)
}
c := content
if strings.Contains(c, "<think>") && strings.Contains(c, "</think>") {
parts := strings.Split(c, "</think>")
c = "<think></think>" + parts[len(parts)-1]
}
return strings.TrimSpace(c)
}
func (r *Nemotron3NanoRenderer) formatToolCallContent(content string, truncate bool) string {
if strings.TrimSpace(content) == "" {
return "<think></think>"
}
if !truncate {
if addNewline {
return strings.TrimSpace(content) + "\n"
}
return strings.TrimSpace(content)
return strings.TrimSpace(content) + "\n"
}
// Truncate thinking - keep only content after </think>
c := content
if strings.Contains(c, "</think>") {
parts := strings.Split(c, "</think>")
@@ -190,13 +209,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
}
c = "<think></think>" + strings.TrimSpace(c)
if addNewline && len(c) > len("<think></think>") {
return c + "\n"
}
if c == "<think></think>" {
return c
}
return strings.TrimSpace(c)
return strings.TrimSpace(c) + "\n"
}
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
@@ -212,9 +225,225 @@ func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []
func (r *Nemotron3NanoRenderer) formatArgValue(value any) string {
switch v := value.(type) {
case map[string]any, []any:
jsonBytes, _ := json.Marshal(v)
return string(jsonBytes)
return r.pythonJSON(v)
default:
return fmt.Sprintf("%v", v)
}
}
// renderMessageContent returns the message text with numbered image
// placeholders of the form [img-N], where N counts up from imageOffset.
//
// Three cases are handled, in this order:
//   - no images, or the text already contains "[img-": the text is returned
//     unchanged;
//   - the text contains generic "[img]" markers: one marker per image is
//     rewritten to its numbered form, left to right;
//   - otherwise one numbered placeholder per image is prepended to the text.
func (r *Nemotron3NanoRenderer) renderMessageContent(message api.Message, imageOffset int) string {
	text := nemotron3NanoRenderContent(message.Content)
	imageCount := len(message.Images)

	switch {
	case imageCount == 0, strings.Contains(text, "[img-"):
		// Nothing to number, or numbering is already present.
		return text
	case strings.Contains(text, "[img]"):
		for idx := 0; idx < imageCount; idx++ {
			numbered := fmt.Sprintf("[img-%d]", imageOffset+idx)
			text = strings.Replace(text, "[img]", numbered, 1)
		}
		return text
	}

	var out strings.Builder
	for idx := 0; idx < imageCount; idx++ {
		fmt.Fprintf(&out, "[img-%d]", imageOffset+idx)
	}
	out.WriteString(text)
	return out.String()
}
// nemotron3NanoRenderContent flattens message content into a single string.
//
// A string is returned as-is. A []any is treated as a list of content parts:
// a map part with type "image" becomes "<image>", a map part with type "text"
// contributes its "text" value when that value is a string, and every other
// part is appended as its JSON encoding. Any other content value is returned
// as its JSON encoding. JSON encoding failures contribute an empty string.
func nemotron3NanoRenderContent(content any) string {
	// asJSON is the fallback rendering for values without a dedicated form.
	asJSON := func(v any) string {
		encoded, _ := json.Marshal(v)
		return string(encoded)
	}

	if text, isString := content.(string); isString {
		return text
	}
	parts, isList := content.([]any)
	if !isList {
		return asJSON(content)
	}

	var out strings.Builder
	for _, part := range parts {
		fields, isObject := part.(map[string]any)
		if !isObject {
			out.WriteString(asJSON(part))
			continue
		}

		// A missing or non-string "type" yields "" and falls to the default arm.
		kind, _ := fields["type"].(string)
		switch kind {
		case "image":
			out.WriteString("<image>")
		case "text":
			if body, isText := fields["text"].(string); isText {
				out.WriteString(body)
			}
		default:
			out.WriteString(asJSON(part))
		}
	}
	return out.String()
}
// resolveThinking decides whether the generation prompt should open a
// thinking block.
//
// The starting value is true when thinkValue is nil and thinkValue.Bool()
// otherwise. Each user or system message is then scanned in order: a "/think"
// occurrence (ignoring any "</think>" closing tags) switches thinking on, and
// otherwise a "/no_think" occurrence switches it off. Later messages override
// earlier ones because every match reassigns the result.
func (r *Nemotron3NanoRenderer) resolveThinking(messages []api.Message, thinkValue *api.ThinkValue) bool {
	thinking := true
	if thinkValue != nil {
		thinking = thinkValue.Bool()
	}

	for i := range messages {
		switch messages[i].Role {
		case "user", "system":
		default:
			continue
		}

		text := messages[i].Content
		// Drop closing tags first so a literal "</think>" is not read as "/think".
		withoutClosers := strings.ReplaceAll(text, "</think>", "")
		switch {
		case strings.Contains(withoutClosers, "/think"):
			thinking = true
		case strings.Contains(text, "/no_think"):
			thinking = false
		}
	}
	return thinking
}
// sanitizeSystemMessage removes the "/think" and "/no_think" toggles from a
// system message while leaving literal "</think>" closing tags intact.
//
// Closing tags are swapped for a placeholder before the toggles are removed
// and swapped back afterwards; without that step, removing "/think" would
// also eat the tail of every "</think>".
func (r *Nemotron3NanoRenderer) sanitizeSystemMessage(content string) string {
	const placeholder = "<_end_think>"

	cleaned := strings.ReplaceAll(nemotron3NanoRenderContent(content), "</think>", placeholder)
	// Order matters: "/think" is removed before "/no_think", as sequential passes.
	for _, toggle := range []string{"/think", "/no_think"} {
		cleaned = strings.ReplaceAll(cleaned, toggle, "")
	}
	return strings.ReplaceAll(cleaned, placeholder, "</think>")
}
// formatPropertyType renders a property's type list for the <type> tag.
// A single type is returned bare (e.g. string); any other length is rendered
// as a bracketed, single-quoted, comma-separated list (e.g.
// ['string', 'null']).
func (r *Nemotron3NanoRenderer) formatPropertyType(propertyType api.PropertyType) string {
	if len(propertyType) == 1 {
		return propertyType[0]
	}

	var list strings.Builder
	list.WriteString("[")
	for i, name := range propertyType {
		if i > 0 {
			list.WriteString(", ")
		}
		list.WriteString("'" + name + "'")
	}
	list.WriteString("]")
	return list.String()
}
// renderToolPropertyExtraKeys appends the optional schema keys of a single
// tool property to sb, each on its own line as <key>JSON</key>. Keys are
// written in the fixed order anyOf, items, properties, required, and a key is
// skipped when its value is empty or nil.
func (r *Nemotron3NanoRenderer) renderToolPropertyExtraKeys(sb *strings.Builder, prop api.ToolProperty) {
	if anyOf := prop.AnyOf; len(anyOf) > 0 {
		body := r.pythonJSON(anyOf)
		sb.WriteString("\n<anyOf>" + body + "</anyOf>")
	}
	if prop.Items != nil {
		body := r.pythonJSON(prop.Items)
		sb.WriteString("\n<items>" + body + "</items>")
	}
	if prop.Properties != nil {
		body := r.pythonJSON(prop.Properties)
		sb.WriteString("\n<properties>" + body + "</properties>")
	}
	if required := prop.Required; len(required) > 0 {
		body := r.pythonJSON(required)
		sb.WriteString("\n<required>" + body + "</required>")
	}
}
// renderToolParameterExtraKeys appends the optional top-level schema keys of
// a tool's parameters object to sb: $defs first, then items. Each is written
// on its own line as <key>JSON</key> and skipped when nil.
func (r *Nemotron3NanoRenderer) renderToolParameterExtraKeys(sb *strings.Builder, params api.ToolFunctionParameters) {
	if params.Defs != nil {
		body := r.pythonJSON(params.Defs)
		sb.WriteString("\n<$defs>" + body + "</$defs>")
	}
	if params.Items != nil {
		body := r.pythonJSON(params.Items)
		sb.WriteString("\n<items>" + body + "</items>")
	}
}
// pythonJSON serializes v as JSON text using ", " between elements and ": "
// between keys and values.
//
// Scalars, string and any slices, map[string]any, and the api tool-schema
// types have dedicated encodings. map[string]any keys are emitted in sorted
// order; *api.ToolPropertiesMap entries are emitted in the order its All
// iterator yields them. api.ToolProperty fields are emitted in the fixed
// order anyOf, type, items, description, enum, properties, required, and
// empty fields are omitted. Any other value is round-tripped through
// encoding/json into generic values and serialized again, so it picks up the
// same separators. Values that cannot be marshaled render as "null".
func (r *Nemotron3NanoRenderer) pythonJSON(v any) string {
	switch value := v.(type) {
	case nil:
		return "null"
	case string:
		return nemotron3NanoJSONString(value)
	case bool:
		return strconv.FormatBool(value)
	case int, int8, int16, int32, int64:
		// value is still an interface here; reflect reads it without one arm per width.
		return strconv.FormatInt(reflect.ValueOf(value).Int(), 10)
	case uint, uint8, uint16, uint32, uint64:
		return strconv.FormatUint(reflect.ValueOf(value).Uint(), 10)
	case float32, float64:
		b, err := json.Marshal(value)
		if err != nil {
			// NaN and ±Inf have no JSON representation.
			return "null"
		}
		return string(b)
	case api.PropertyType:
		return r.pythonJSON([]string(value))
	case []string:
		parts := make([]string, 0, len(value))
		for _, item := range value {
			parts = append(parts, nemotron3NanoJSONString(item))
		}
		return "[" + strings.Join(parts, ", ") + "]"
	case []any:
		parts := make([]string, 0, len(value))
		for _, item := range value {
			parts = append(parts, r.pythonJSON(item))
		}
		return "[" + strings.Join(parts, ", ") + "]"
	case []api.ToolProperty:
		parts := make([]string, 0, len(value))
		for _, item := range value {
			parts = append(parts, r.pythonJSON(item))
		}
		return "[" + strings.Join(parts, ", ") + "]"
	case map[string]any:
		keys := make([]string, 0, len(value))
		for key := range value {
			keys = append(keys, key)
		}
		// Go map iteration order is random; sort for deterministic output.
		sort.Strings(keys)
		parts := make([]string, 0, len(keys))
		for _, key := range keys {
			parts = append(parts, nemotron3NanoJSONString(key)+": "+r.pythonJSON(value[key]))
		}
		return "{" + strings.Join(parts, ", ") + "}"
	case *api.ToolPropertiesMap:
		if value == nil {
			return "null"
		}
		parts := make([]string, 0, value.Len())
		for key, prop := range value.All() {
			parts = append(parts, nemotron3NanoJSONString(key)+": "+r.pythonJSON(prop))
		}
		return "{" + strings.Join(parts, ", ") + "}"
	case api.ToolProperty:
		parts := make([]string, 0, 7)
		if len(value.AnyOf) > 0 {
			parts = append(parts, `"anyOf": `+r.pythonJSON(value.AnyOf))
		}
		if len(value.Type) > 0 {
			// A single type is written as a bare string, several as a list.
			if len(value.Type) == 1 {
				parts = append(parts, `"type": `+r.pythonJSON(value.Type[0]))
			} else {
				parts = append(parts, `"type": `+r.pythonJSON([]string(value.Type)))
			}
		}
		if value.Items != nil {
			parts = append(parts, `"items": `+r.pythonJSON(value.Items))
		}
		if value.Description != "" {
			parts = append(parts, `"description": `+r.pythonJSON(value.Description))
		}
		if len(value.Enum) > 0 {
			parts = append(parts, `"enum": `+r.pythonJSON(value.Enum))
		}
		if value.Properties != nil {
			parts = append(parts, `"properties": `+r.pythonJSON(value.Properties))
		}
		if len(value.Required) > 0 {
			parts = append(parts, `"required": `+r.pythonJSON(value.Required))
		}
		return "{" + strings.Join(parts, ", ") + "}"
	default:
		b, err := json.Marshal(value)
		if err != nil {
			return "null"
		}
		var generic any
		if err := json.Unmarshal(b, &generic); err != nil {
			return string(b)
		}
		// generic now holds only nil, bool, float64, string, []any or
		// map[string]any, all of which have dedicated arms above.
		return r.pythonJSON(generic)
	}
}

// nemotron3NanoJSONString returns s as a JSON string literal. It uses
// encoding/json rather than strconv.Quote because Go-syntax escapes such as
// \x01, \a and \v are not valid JSON. HTML escaping is disabled so '<', '>'
// and '&' are written literally.
func nemotron3NanoJSONString(s string) string {
	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(s); err != nil {
		// Encoding a plain string is not expected to fail; keep a usable fallback.
		return strconv.Quote(s)
	}
	// Encode terminates its output with a newline.
	return strings.TrimSuffix(sb.String(), "\n")
}

View File

@@ -0,0 +1,614 @@
package renderers
import (
"encoding/json"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// nemotron3NanoTemplate is the path of the reference Jinja2 chat template,
// relative to the model/renderers directory under the repository root.
const nemotron3NanoTemplate = "testdata/nemotron3nano_chat_template.jinja2"
// TestNemotron3NanoRendererMatchesReference renders each table case with
// Nemotron3NanoRenderer and compares the output to that case's golden string.
// When the VERIFY_JINJA2 environment variable is non-empty, each case is also
// rendered through renderNemotron3NanoWithJinja2 and compared to the same
// golden string; in that mode the test fails up front if .venv/bin/python3 is
// missing under the repository root.
func TestNemotron3NanoRendererMatchesReference(t *testing.T) {
toolText := `<|im_start|>system
# Tools
You have access to the following functions:
<tools>
<function>
<name>search_docs</name>
<description>Search docs</description>
<parameters>
<parameter>
<name>query</name>
<type>string</type>
<description>Search query</description>
<enum>["api", "cli"]</enum>
</parameter>
<parameter>
<name>mode</name>
<type>['string', 'null']</type>
<description>Mode</description>
<anyOf>[{"type": "string"}, {"type": "number"}]</anyOf>
</parameter>
<parameter>
<name>payload</name>
<type>object</type>
<description>Payload</description>
<properties>{"enabled": {"type": "boolean"}}</properties>
<required>["enabled"]</required>
</parameter>
<parameter>
<name>tags</name>
<type>array</type>
<description>Tags</description>
<items>{"type": "string"}</items>
</parameter>
<$defs>{"shared": {"type": "string"}}</$defs>
<required>["query"]</required>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
`
toolTextWithSystem := `<|im_start|>system
Follow policy.
# Tools
You have access to the following functions:
<tools>
<function>
<name>search_docs</name>
<description>Search docs</description>
<parameters>
<parameter>
<name>query</name>
<type>string</type>
<description>Search query</description>
<enum>["api", "cli"]</enum>
</parameter>
<parameter>
<name>mode</name>
<type>['string', 'null']</type>
<description>Mode</description>
<anyOf>[{"type": "string"}, {"type": "number"}]</anyOf>
</parameter>
<parameter>
<name>payload</name>
<type>object</type>
<description>Payload</description>
<properties>{"enabled": {"type": "boolean"}}</properties>
<required>["enabled"]</required>
</parameter>
<parameter>
<name>tags</name>
<type>array</type>
<description>Tags</description>
<items>{"type": "string"}</items>
</parameter>
<$defs>{"shared": {"type": "string"}}</$defs>
<required>["query"]</required>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
`
tests := []struct {
name string
messages []api.Message
tools []api.Tool
think *api.ThinkValue
expected string
}{
{
name: "no system default thinking on",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "no system explicit thinking off",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
think: thinkFalse(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
},
{
name: "literal endthink does not enable thinking",
messages: []api.Message{
{Role: "user", Content: "literal </think> only"},
},
think: thinkFalse(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nliteral </think> only<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
},
{
name: "user no think toggle",
messages: []api.Message{
{Role: "user", Content: "Hello /no_think"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello /no_think<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
},
{
name: "system think toggle overrides false",
messages: []api.Message{
{Role: "system", Content: "Policy /think"},
{Role: "user", Content: "Hello"},
},
think: thinkFalse(),
expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "later toggle wins",
messages: []api.Message{
{Role: "system", Content: "Policy /no_think"},
{Role: "user", Content: "Actually /think"},
},
think: thinkFalse(),
expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nActually /think<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "system sanitizes toggles but preserves closing tag",
messages: []api.Message{
{Role: "system", Content: "A /think B /no_think C </think>"},
{Role: "user", Content: "Hello"},
},
think: thinkFalse(),
expected: "\n\n\n<|im_start|>system\nA B C </think><|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant plain content adds empty think block",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello there"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Hello there<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant reasoning content",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Answer", Thinking: "Need to think"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>\nNeed to think\n</think>\nAnswer<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant preserves existing think tags",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "<think>kept</think>Answer"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>kept</think>Answer<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "tools without system",
messages: []api.Message{
{Role: "user", Content: "Use a tool"},
},
tools: nemotron3NanoReferenceTools(),
think: thinkTrue(),
expected: "\n\n\n" + toolText + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "system with tools",
messages: []api.Message{
{Role: "system", Content: "Follow policy."},
{Role: "user", Content: "Use a tool"},
},
tools: nemotron3NanoReferenceTools(),
think: thinkTrue(),
expected: "\n\n\n" + toolTextWithSystem + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant tool call with content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Checking now.",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
}},
},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>Checking now.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant tool call with structured arguments",
messages: []api.Message{
{Role: "user", Content: "Create data"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "create",
Arguments: testArgsOrdered([]orderedArg{
{Key: "payload", Value: map[string]any{"count": 42, "nested": map[string]any{"value": "ok"}}},
{Key: "tags", Value: []any{"a", "b"}},
}),
},
}},
},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nCreate data<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=create>\n<parameter=payload>\n{\"count\": 42, \"nested\": {\"value\": \"ok\"}}\n</parameter>\n<parameter=tags>\n[\"a\", \"b\"]\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant tool call truncated with reasoning",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Checking now.",
Thinking: "Need weather",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
}},
},
{Role: "user", Content: "And tomorrow?"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>Checking now.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant tool call truncated open think only",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "<think>draft",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
}},
},
{Role: "user", Content: "And tomorrow?"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant tool call empty content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
}},
},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant truncated with think pair",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "<think>hidden</think>Visible"},
{Role: "user", Content: "Next"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Visible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant truncated reasoning content",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Thinking: "hidden", Content: "Visible"},
{Role: "user", Content: "Next"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>\nVisible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant truncated plain content",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Visible"},
{Role: "user", Content: "Next"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Visible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "assistant truncated empty content",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "Next"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think><|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "consecutive tool messages grouped",
messages: []api.Message{
{Role: "user", Content: "Do work"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "step",
Arguments: testArgs(map[string]any{"value": 1}),
},
}},
},
{Role: "tool", Content: "one"},
{Role: "tool", Content: "two"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=step>\n<parameter=value>\n1\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\n<tool_response>\none\n</tool_response>\n<tool_response>\ntwo\n</tool_response>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "fallback role",
messages: []api.Message{
{Role: "developer", Content: "Custom role content"},
},
think: thinkTrue(),
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>developer\nCustom role content<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
}
verifyJinja2 := os.Getenv("VERIFY_JINJA2") != ""
if verifyJinja2 {
if _, err := os.Stat(filepath.Join(nemotron3NanoRepoRoot(t), ".venv", "bin", "python3")); err != nil {
t.Fatal("VERIFY_JINJA2=1 requires .venv/bin/python3")
}
}
renderer := &Nemotron3NanoRenderer{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, got); diff != "" {
t.Fatalf("renderer mismatch (-want +got):\n%s", diff)
}
if verifyJinja2 {
jinja2Output := renderNemotron3NanoWithJinja2(t, tt.messages, tt.tools, tt.think)
if diff := cmp.Diff(tt.expected, jinja2Output); diff != "" {
t.Fatalf("reference template mismatch (-want +got):\n%s", diff)
}
}
})
}
}
// nemotron3NanoReferenceTools returns the single "search_docs" tool used by
// the reference-template cases. Its four parameters (query, mode, payload,
// tags, in that order) between them carry enum, a multi-valued type, anyOf,
// nested properties with required, and items; the parameters object also
// carries $defs.
func nemotron3NanoReferenceTools() []api.Tool {
	query := orderedProp{
		Key: "query",
		Value: api.ToolProperty{
			Type:        api.PropertyType{"string"},
			Description: "Search query",
			Enum:        []any{"api", "cli"},
		},
	}

	mode := orderedProp{
		Key: "mode",
		Value: api.ToolProperty{
			Type:        api.PropertyType{"string", "null"},
			Description: "Mode",
			AnyOf: []api.ToolProperty{
				{Type: api.PropertyType{"string"}},
				{Type: api.PropertyType{"number"}},
			},
		},
	}

	enabled := orderedProp{
		Key:   "enabled",
		Value: api.ToolProperty{Type: api.PropertyType{"boolean"}},
	}
	payload := orderedProp{
		Key: "payload",
		Value: api.ToolProperty{
			Type:        api.PropertyType{"object"},
			Description: "Payload",
			Properties:  testPropsOrdered([]orderedProp{enabled}),
			Required:    []string{"enabled"},
		},
	}

	tags := orderedProp{
		Key: "tags",
		Value: api.ToolProperty{
			Type:        api.PropertyType{"array"},
			Description: "Tags",
			Items:       map[string]any{"type": "string"},
		},
	}

	searchDocs := api.ToolFunction{
		Name:        "search_docs",
		Description: "Search docs",
		Parameters: api.ToolFunctionParameters{
			Type:       "object",
			Defs:       map[string]any{"shared": map[string]any{"type": "string"}},
			Required:   []string{"query"},
			Properties: testPropsOrdered([]orderedProp{query, mode, payload, tags}),
		},
	}

	return []api.Tool{{Type: "function", Function: searchDocs}}
}
// renderNemotron3NanoWithJinja2 renders messages and tools through the
// reference Jinja2 chat template by running a short Python script with the
// interpreter at <repo root>/.venv/bin/python3, and returns the script's
// stdout.
//
// Messages and tools are passed to the script as JSON command-line arguments;
// the tools argument is the literal "None" when no tools are given. think maps
// to the template's enable_thinking argument: nil leaves it unset, otherwise
// it is passed as True or False. Any marshaling or subprocess failure fails
// the test immediately.
func renderNemotron3NanoWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
	t.Helper()

	type jinja2ToolCall struct {
		ID       string `json:"id,omitempty"`
		Function struct {
			Name      string `json:"name"`
			Arguments any    `json:"arguments"`
		} `json:"function"`
	}
	type jinja2Message struct {
		Role             string           `json:"role"`
		Content          string           `json:"content"`
		ReasoningContent string           `json:"reasoning_content,omitempty"`
		ToolCalls        []jinja2ToolCall `json:"tool_calls,omitempty"`
		Name             string           `json:"name,omitempty"`
		ToolCallID       string           `json:"tool_call_id,omitempty"`
	}

	// Non-nil even when empty, so the script receives [] rather than null.
	jMsgs := make([]jinja2Message, 0, len(messages))
	for _, m := range messages {
		jm := jinja2Message{
			Role:             m.Role,
			Content:          m.Content,
			ReasoningContent: m.Thinking,
			Name:             m.ToolName,
			ToolCallID:       m.ToolCallID,
		}
		for _, tc := range m.ToolCalls {
			jtc := jinja2ToolCall{ID: tc.ID}
			jtc.Function.Name = tc.Function.Name

			raw, err := tc.Function.Arguments.MarshalJSON()
			if err != nil {
				t.Fatalf("failed to marshal tool args: %v", err)
			}
			// Decode only to validate that the arguments form a JSON object.
			var args map[string]any
			if err := json.Unmarshal(raw, &args); err != nil {
				t.Fatalf("failed to unmarshal tool args: %v", err)
			}
			// Forward the marshaled bytes themselves: re-encoding a Go map
			// would emit keys in sorted order, discarding whatever order
			// MarshalJSON produced.
			jtc.Function.Arguments = json.RawMessage(raw)

			jm.ToolCalls = append(jm.ToolCalls, jtc)
		}
		jMsgs = append(jMsgs, jm)
	}

	msgsJSON, err := json.Marshal(jMsgs)
	if err != nil {
		t.Fatalf("failed to marshal messages: %v", err)
	}

	toolsJSON := "None"
	if len(tools) > 0 {
		b, err := json.Marshal(tools)
		if err != nil {
			t.Fatalf("failed to marshal tools: %v", err)
		}
		toolsJSON = string(b)
	}

	thinking := "unset"
	if think != nil {
		if think.Bool() {
			thinking = "true"
		} else {
			thinking = "false"
		}
	}

	repoRoot := nemotron3NanoRepoRoot(t)
	templatePath := filepath.Join(repoRoot, "model", "renderers", nemotron3NanoTemplate)
	pythonPath := filepath.Join(repoRoot, ".venv", "bin", "python3")

	script := `
import json
import sys
from pathlib import Path
from transformers.utils.chat_template_utils import _compile_jinja_template
template_path, messages_json, tools_json, thinking = sys.argv[1:5]
tmpl = _compile_jinja_template(Path(template_path).read_text())
kwargs = {
    "messages": json.loads(messages_json),
    "add_generation_prompt": True,
}
if tools_json != "None":
    kwargs["tools"] = json.loads(tools_json)
if thinking == "true":
    kwargs["enable_thinking"] = True
elif thinking == "false":
    kwargs["enable_thinking"] = False
print(tmpl.render(**kwargs), end="")
`
	cmd := exec.Command(pythonPath, "-c", script, templatePath, string(msgsJSON), toolsJSON, thinking)
	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		t.Fatalf("python render failed: %v\nstderr: %s", err, stderr.String())
	}
	return stdout.String()
}
// nemotron3NanoRepoRoot returns the directory two levels above the directory
// that contains this source file, located via runtime.Caller. It fails the
// test if the caller information is unavailable.
func nemotron3NanoRepoRoot(t *testing.T) string {
	t.Helper()

	_, thisFile, _, found := runtime.Caller(0)
	if !found {
		t.Fatal("failed to locate test file")
	}

	// First step drops the file name; the next two climb two directories.
	root := thisFile
	for up := 0; up < 3; up++ {
		root = filepath.Dir(root)
	}
	return root
}

View File

@@ -8,561 +8,46 @@ import (
"github.com/ollama/ollama/api"
)
func TestNemotron3NanoRenderer(t *testing.T) {
func TestNemotron3NanoRenderer_Images(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
name string
msgs []api.Message
expected string
}{
{
name: "basic user message - thinking mode",
name: "single image inserts placeholder",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHello!<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "basic user message - no thinking",
name: "generic image placeholder is rewritten",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "[img]Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
},
thinkValue: nil,
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHello!<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>",
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
{
name: "with system message",
name: "image offsets increment across messages",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Describe the first image.", Images: []api.ImageData{api.ImageData("img1")}},
{Role: "assistant", Content: "It shows something."},
{Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\nHello!<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "multi-turn conversation",
msgs: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello! How can I help?"},
{Role: "user", Content: "Tell me a joke"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHi<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>Hello! How can I help?<|im_end|>\n" +
"<|im_start|>user\nTell me a joke<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "with tools",
msgs: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_weather</name>\n" +
"<description>Get the current weather</description>\n" +
"<parameters>\n" +
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
"<required>[\"city\"]</required>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "tool call with response",
msgs: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_weather</name>\n" +
"<description>Get the current weather</description>\n" +
"<parameters>\n" +
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
"<required>[\"city\"]</required>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nSunny, 72F\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "assistant with content and tool call",
msgs: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check that for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_weather</name>\n" +
"<parameters>\n" +
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nWhat's the weather?<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>Let me check that for you.\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "thinking in history is truncated",
msgs: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."},
{Role: "user", Content: "How are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHi<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>Hello!<|im_end|>\n" +
"<|im_start|>user\nHow are you?<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "parallel tool calls",
msgs: []api.Message{
{Role: "user", Content: "Weather in Paris and London?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "London"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
{Role: "tool", Content: "Rainy"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_weather</name>\n" +
"<parameters>\n" +
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<tool_response>\nRainy\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "thinking disabled when user doesn't request it",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
thinkValue: nil,
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHello!<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>",
},
{
name: "complex message history with thinking, tools, tool calls, tool results and content",
msgs: []api.Message{
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
}},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
}},
{Role: "tool", Content: "4", ToolCallID: "call3"},
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "calculate",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"expression": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_weather</name>\n" +
"<parameters>\n" +
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n" +
"<function>\n<name>calculate</name>\n" +
"<parameters>\n" +
"<parameter>\n<name>expression</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>\nI need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.\n</think>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nSunny, 22°C\n</tool_response>\n<tool_response>\nRainy, 15°C\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>\nNow I have the weather data. Let me calculate 2+2.\n</think>\n" +
"<tool_call>\n<function=calculate>\n<parameter=expression>\n2+2\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\n4\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>\nPerfect! I have all the information needed to provide a complete answer.\n</think>\n" +
"Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "empty messages list",
msgs: []api.Message{},
thinkValue: nil,
expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n<think></think>",
},
{
name: "tool result with JSON content",
msgs: []api.Message{
{Role: "user", Content: "Get user info"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
},
},
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_user",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>get_user</name>\n<parameters>\n" +
"<parameter>\n<name>id</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nGet user info<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>\n" +
"<tool_call>\n<function=get_user>\n<parameter=id>\n123\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "assistant message with only thinking no content",
msgs: []api.Message{
{Role: "user", Content: "Think about this"},
{Role: "assistant", Thinking: "Deep thoughts here...", Content: ""},
{Role: "user", Content: "What did you think?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nThink about this<|im_end|>\n" +
"<|im_start|>assistant\n<think></think><|im_end|>\n" +
"<|im_start|>user\nWhat did you think?<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "tool call with complex nested argument",
msgs: []api.Message{
{Role: "user", Content: "Create data"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "create",
Arguments: testArgs(map[string]any{
"data": map[string]any{"nested": "value", "count": 42},
}),
}},
},
},
{Role: "tool", Content: "Created"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "create",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>create</name>\n<parameters>\n" +
"<parameter>\n<name>data</name>\n<type>object</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nCreate data<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>\n" +
"<tool_call>\n<function=create>\n<parameter=data>\n{\"count\":42,\"nested\":\"value\"}\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nCreated\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "content explaining the format itself",
msgs: []api.Message{
{Role: "user", Content: "How do I format a tool call?"},
{Role: "assistant", Content: "To call a tool, use <tool_call> tags with <function=name> inside."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n<|im_end|>\n" +
"<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>To call a tool, use <tool_call> tags with <function=name> inside.<|im_end|>\n" +
"<|im_start|>user\nThanks!<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
},
{
name: "unicode in content and tool args",
msgs: []api.Message{
{Role: "user", Content: "Translate 你好"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
},
},
{Role: "tool", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "translate",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"text": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>system\n" +
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
"<function>\n<name>translate</name>\n<parameters>\n" +
"<parameter>\n<name>text</name>\n<type>string</type>\n</parameter>\n" +
"</parameters>\n</function>\n</tools>\n\n" +
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
"- Required parameters MUST be specified\n" +
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
"</IMPORTANT><|im_end|>\n" +
"<|im_start|>user\nTranslate 你好<|im_end|>\n" +
"<|im_start|>assistant\n<think></think>\n" +
"<tool_call>\n<function=translate>\n<parameter=text>\n你好\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
"<|im_start|>user\n<tool_response>\nHello\n</tool_response>\n<|im_end|>\n" +
"<|im_start|>assistant\n<think>\n",
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2]Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &Nemotron3NanoRenderer{}
rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue)
rendered, err := renderer.Render(tt.msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Fatalf("mismatch (-want +got):\n%s", diff)
}
})
}

View File

@@ -95,6 +95,8 @@ func rendererForName(name string) Renderer {
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
case "laguna":
return &LagunaRenderer{}
default:
return nil
}

View File

@@ -0,0 +1,222 @@
{#- render_extra_keys(json_dict, handled_keys)
    Emits one newline-prefixed <key>value</key> element for every key of
    json_dict that is not listed in handled_keys. Values that are mappings, or
    sequences other than strings, are serialized with tojson; every other value
    is converted with the string filter. Produces no output when json_dict is
    not a mapping. -#}
{% macro render_extra_keys(json_dict, handled_keys) %}
{%- if json_dict is mapping %}
{%- for json_key in json_dict if json_key not in handled_keys %}
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
{%- else %}
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
{%- endif %}
{%- endfor %}
{%- endif %}
{% endmacro %}
{#- Chat template body.
    Inputs read below: messages (required), tools, enable_thinking,
    truncate_history_thinking and add_generation_prompt. reasoning_budget and
    response_format are given defaults here but are not referenced again in
    this template.
    Output: a system turn, then one turn per message wrapped in the
    im_start / im_end markers, then an optional assistant generation prompt. -#}
{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
{%- set reasoning_budget = reasoning_budget if reasoning_budget is defined else None %}
{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
{%- set response_format = response_format if response_format is defined else None %}
{# Scan messages for VLM thinking toggles to override enable_thinking #}
{#- Only string content of user and system messages is inspected. The closing
    think tag is removed before searching so it is not mistaken for the /think
    toggle; /think is tested before /no_think, and because every message is
    visited the last message containing a toggle decides the final value. -#}
{%- set toggle = namespace(enable=enable_thinking) %}
{%- for m in messages %}
{%- if m['role'] == 'user' or m['role'] == 'system' -%}
{%- if m['content'] is string -%}
{%- set c = m['content'] %}
{%- if '/think' in c.replace('</think>', '') -%}
{%- set toggle.enable = true -%}
{%- elif '/no_think' in c -%}
{%- set toggle.enable = false -%}
{%- endif -%}
{%- endif -%}
{%- endif -%}
{%- endfor %}
{# Prepare message iteration similar to LM template #}
{#- NOTE(review): the namespace built by this first pass is replaced by the
    recomputation further down before it is ever read, so this loop has no
    effect on the output. -#}
{%- set ns = namespace(last_user_idx = -1) %}
{%- set loop_messages = messages %}
{%- for m in loop_messages %}
{%- if m["role"] == "user" %}
{%- set ns.last_user_idx = loop.index0 %}
{%- endif %}
{%- endfor %}
{#- A leading system message is lifted out of the loop and rendered as the
    system preamble. system_message is assigned in both branches, so the later
    "system_message is defined" tests always take their first branch and a
    system turn is always emitted, possibly with empty content.
    NOTE(review): messages[0] is read without an emptiness check; confirm how
    an empty messages list is expected to behave. -#}
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = [] %}
{%- endif %}
{# Recompute last_user_idx relative to loop_messages after handling system #}
{%- set ns = namespace(last_user_idx = -1) %}
{%- for m in loop_messages %}
{%- if m["role"] == "user" %}
{%- set ns.last_user_idx = loop.index0 %}
{%- endif %}
{%- endfor %}
{# System preamble with LM formatting, sanitize thinking toggles #}
{%- if system_message is defined %}
{%- set sys_content = system_message | string %}
{%- set sys_content = sys_content.replace('</think>', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '</think>') %}
{{- "<|im_start|>system\n" + sys_content }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- "<|im_start|>system\n" }}
{%- endif %}
{%- endif %}
{%- if tools is iterable and tools | length > 0 %}
{%- if system_message is defined and system_message | length > 0 %}
{{- "\n\n" }}
{%- endif %}
{{- "# Tools\n\nYou have access to the following functions:\n\n" }}
{{- "<tools>" }}
{%- for tool in tools %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
{%- if tool.description is defined %}
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
{%- endif %}
{{- '\n<parameters>' }}
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{{- '\n<parameter>' }}
{{- '\n<name>' ~ param_name ~ '</name>' }}
{%- if param_fields.type is defined %}
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
{%- endif %}
{%- if param_fields.description is defined %}
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
{%- endif %}
{%- if param_fields.enum is defined %}
{{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
{%- endif %}
{%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
{{- render_extra_keys(param_fields, handled_keys) }}
{{- '\n</parameter>' }}
{%- endfor %}
{%- endif %}
{% set handled_keys = ['type', 'properties', 'required'] %}
{{- render_extra_keys(tool.parameters, handled_keys) }}
{%- if tool.parameters is defined and tool.parameters.required is defined %}
{{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
{%- endif %}
{{- '\n</parameters>' }}
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
{{- render_extra_keys(tool, handled_keys) }}
{{- '\n</function>' }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- endif %}
{%- if system_message is defined %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{# Iterate conversation #}
{%- for message in loop_messages %}
{%- if message.role == "assistant" %}
{# Use LM assistant handling #}
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
{%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ (message.content | default('', true)) %}
{%- else %}
{%- set content = message.content | default('', true) %}
{%- if content is string -%}
{%- if '<think>' not in content and '</think>' not in content -%}
{%- set content = "<think></think>" ~ content -%}
{%- endif -%}
{%- else -%}
{%- set content = content -%}
{%- endif -%}
{%- endif %}
{%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
{{- '<|im_start|>assistant\n' }}
{#- Assistant turns positioned before the last user message count as history:
    when truncate_history_thinking is set, only the text after the closing
    think tag is kept and it is prefixed with an empty think block. -#}
{%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
{%- if content is string and content | trim | length > 0 %}
{%- if include_content %}
{{- (content | trim) ~ '\n' -}}
{%- else %}
{%- set c = (content | string) %}
{%- if '</think>' in c %}
{%- set c = c.split('</think>')[-1] %}
{%- elif '<think>' in c %}
{%- set c = c.split('<think>')[0] %}
{%- endif %}
{%- set c = "<think></think>" ~ c | trim %}
{%- if c | length > 0 %}
{{- c ~ '\n' -}}
{%- endif %}
{%- endif %}
{%- else %}
{{- "<think></think>" -}}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' ~ args_name ~ '>\n' -}}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value ~ '\n</parameter>\n' -}}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>\n' -}}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
{{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
{%- else %}
{%- set c = (content | default('', true) | string) %}
{%- if '<think>' in c and '</think>' in c %}
{%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
{%- endif %}
{%- set c = c | trim %}
{%- if c | length > 0 %}
{{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>assistant\n<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
{%- elif message.role == "user" or message.role == "system" %}
{{- '<|im_start|>' + message.role + '\n' }}
{%- set content = message.content | string %}
{{- content }}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{#- A run of consecutive tool messages is folded into a single user turn: the
    opening marker is written only when the previous message was not a tool
    message, and the closing marker only after the last tool message of the run.
    NOTE(review): when a tool message is the first loop item there is no
    previtem, so no opening marker is written; confirm that case cannot occur. -#}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user\n' }}
{%- endif %}
{{- '<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>\n' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{# Generation prompt using computed thinking toggle #}
{%- if add_generation_prompt %}
{%- if toggle.enable %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- else %}
{{- '<|im_start|>assistant\n<think></think>' }}
{%- endif %}
{%- endif %}

View File

@@ -494,15 +494,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
ft := layer.GGML.KV().FileType()
if quantType == "" && hasSourceFP8Tensors(layer.GGML.KV()) && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" && slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
quantType = "Q8_0"
}
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
want, err := ggml.ParseFileType(quantType)
if err != nil {
return err
}
ft := layer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models")
if !slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16, BF16 and F32 models")
} else if ft != want {
layer, err = quantizeLayer(layer, quantType, fn)
if err != nil {
@@ -531,6 +534,12 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
r.Parameters["stop"] = []string{"<turn|>"}
}
case "laguna":
config.Renderer = cmp.Or(config.Renderer, "laguna")
config.Parser = cmp.Or(config.Parser, "laguna")
case "nemotron_h", "nemotron_h_moe", "nemotron_h_omni":
config.Renderer = cmp.Or(config.Renderer, "nemotron-3-nano")
config.Parser = cmp.Or(config.Parser, "nemotron-3-nano")
}
}
}
@@ -606,6 +615,10 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return nil
}
// hasSourceFP8Tensors reports whether the GGUF metadata marks the model as
// converted from HF FP8 weights: "source_quantization" must equal "hf_fp8"
// and "source_fp8_tensors" must list at least one tensor name.
func hasSourceFP8Tensors(kv ggml.KV) bool {
	if kv.String("source_quantization") != "hf_fp8" {
		return false
	}
	return len(kv.Strings("source_fp8_tensors")) > 0
}
func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
ft := layer.GGML.KV().FileType()
var doneBytes atomic.Uint64

View File

@@ -0,0 +1,90 @@
package server
import (
"testing"
fsggml "github.com/ollama/ollama/fs/ggml"
)
// TestLagunaGGUFQuantization is a table-driven check of lagunaGGUFQuantization.
// Each case supplies a tensor name, its original type, the requested tensor
// type, the target file type and the block count, and pins both the returned
// tensor type and the returned "quantize" flag.
func TestLagunaGGUFQuantization(t *testing.T) {
cases := []struct {
name string
tensor string
originalType fsggml.TensorType
requestedType fsggml.TensorType
fileType fsggml.FileType
blockCount int
wantType fsggml.TensorType
wantQuantize bool
}{
{
// Attention weights are not routed-expert tensors: original type is kept.
name: "non_routed_weights_preserved",
tensor: "blk.1.attn_q.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ8_0,
fileType: fsggml.FileTypeQ8_0,
blockCount: 2,
wantType: fsggml.TensorTypeBF16,
wantQuantize: false,
},
{
// "shexp" (shared expert) names do not match the "_exps" routed patterns.
name: "shared_expert_weights_preserved",
tensor: "blk.1.ffn_gate_shexp.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ4_K,
fileType: fsggml.FileTypeQ4_K_M,
blockCount: 2,
wantType: fsggml.TensorTypeBF16,
wantQuantize: false,
},
{
// Routed gate weights take the requested type unchanged.
name: "routed_gate_q8",
tensor: "blk.1.ffn_gate_exps.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ8_0,
fileType: fsggml.FileTypeQ8_0,
blockCount: 2,
wantType: fsggml.TensorTypeQ8_0,
wantQuantize: true,
},
{
// Q4_K_M file type: routed down weights in this block are bumped to Q6_K.
name: "routed_down_q4_promoted",
tensor: "blk.1.ffn_down_exps.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ4_K,
fileType: fsggml.FileTypeQ4_K_M,
blockCount: 2,
wantType: fsggml.TensorTypeQ6_K,
wantQuantize: true,
},
{
// An explicit Q8_0 request suppresses the Q4_K_M promotion.
name: "routed_down_q4_not_promoted_when_q8_requested",
tensor: "blk.1.ffn_down_exps.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ8_0,
fileType: fsggml.FileTypeQ4_K_M,
blockCount: 2,
wantType: fsggml.TensorTypeQ8_0,
wantQuantize: true,
},
{
// Q4_K_S file type: block 0 of 8 falls in the first eighth and gets Q5_K.
name: "routed_down_q4_k_s_promoted",
tensor: "blk.0.ffn_down_exps.weight",
originalType: fsggml.TensorTypeBF16,
requestedType: fsggml.TensorTypeQ4_K,
fileType: fsggml.FileTypeQ4_K_S,
blockCount: 8,
wantType: fsggml.TensorTypeQ5_K,
wantQuantize: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
gotType, gotQuantize := lagunaGGUFQuantization(tt.tensor, tt.originalType, tt.requestedType, tt.fileType, tt.blockCount)
if gotType != tt.wantType || gotQuantize != tt.wantQuantize {
t.Fatalf("lagunaGGUFQuantization(%q) = (%s, %v), want (%s, %v)", tt.tensor, gotType, gotQuantize, tt.wantType, tt.wantQuantize)
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"maps"
"os"
"slices"
"strconv"
"strings"
"unsafe"
@@ -51,11 +52,14 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
}
type quantizeState struct {
nAttnV int // Number of attn_*v* weight tensors
nFfnDown int // Number of ffn_down tensors
iAttnV int // Running counter of number of attn_v tensors that have been processed
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
nAttnV int // Number of attn_*v* weight tensors
nFfnDown int // Number of ffn_down tensors
iAttnV int // Running counter of number of attn_v tensors that have been processed
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
preserveSourceFP8ToQ8 bool
preserveSourceQ4 bool
sourceFP8Tensors map[string]struct{}
}
func useMoreBits(iLayer, nLayers int) bool {
@@ -108,6 +112,53 @@ func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
return 0, false
}
// isLagunaGGUFRoutedExpertWeight reports whether name is a ".weight" tensor
// whose name contains one of the routed-expert markers ffn_gate_exps,
// ffn_up_exps or ffn_down_exps.
func isLagunaGGUFRoutedExpertWeight(name string) bool {
	if !strings.HasSuffix(name, ".weight") {
		return false
	}
	switch {
	case strings.Contains(name, "ffn_gate_exps"),
		strings.Contains(name, "ffn_up_exps"),
		strings.Contains(name, "ffn_down_exps"):
		return true
	}
	return false
}
// lagunaGGUFBlockIndex extracts N from a tensor name of the form
// "blk.N.<rest>". The second result is false when the name lacks the "blk."
// prefix, has no "." after the index, or the index is not a valid integer.
func lagunaGGUFBlockIndex(name string) (int, bool) {
	rest, hasPrefix := strings.CutPrefix(name, "blk.")
	if !hasPrefix {
		return 0, false
	}
	// The index must be followed by another dot-separated component.
	digits, _, hasRest := strings.Cut(rest, ".")
	if !hasRest {
		return 0, false
	}
	index, err := strconv.Atoi(digits)
	if err != nil {
		return 0, false
	}
	return index, true
}
// lagunaGGUFQuantization chooses the tensor type for a Laguna GGUF tensor.
//
// Tensors that are not routed-expert weights return (originalType, false).
// Routed-expert weights return (requestedType, true), except that
// ".ffn_down_exps.weight" tensors with a parseable block index and a positive
// blockCount are promoted when requestedType is not Q8_0:
//   - Q4_K_M file type and useMoreBits(index, blockCount): Q6_K
//   - Q4_K_S file type and index < blockCount/8: Q5_K
func lagunaGGUFQuantization(name string, originalType, requestedType fsggml.TensorType, ftype fsggml.FileType, blockCount int) (fsggml.TensorType, bool) {
	if !isLagunaGGUFRoutedExpertWeight(name) {
		return originalType, false
	}

	// Promotion only ever applies to down-projection experts, never to an
	// explicit Q8_0 request, and needs a usable block count.
	if !strings.HasSuffix(name, ".ffn_down_exps.weight") || requestedType == fsggml.TensorTypeQ8_0 || blockCount <= 0 {
		return requestedType, true
	}

	layer, ok := lagunaGGUFBlockIndex(name)
	if !ok {
		return requestedType, true
	}

	switch {
	case ftype == fsggml.FileTypeQ4_K_M && useMoreBits(layer, blockCount):
		return fsggml.TensorTypeQ6_K, true
	case ftype == fsggml.FileTypeQ4_K_S && layer < blockCount/8:
		return fsggml.TensorTypeQ5_K, true
	}
	return requestedType, true
}
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
// Ported from llama_tensor_get_type, removed unsupported quantization types
nExperts := max(1, kv.Uint("expert_count", 0))
@@ -120,10 +171,10 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
newType = fsggml.TensorTypeQ6_K
}
} else if strings.Contains(name, "attn_v.weight") {
if (ftype == fsggml.FileTypeQ4_K_M) &&
if newType != fsggml.TensorTypeQ8_0 && (ftype == fsggml.FileTypeQ4_K_M) &&
useMoreBits(qs.iAttnV, qs.nAttnV) {
newType = fsggml.TensorTypeQ6_K
} else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
newType = fsggml.TensorTypeQ5_K
}
@@ -158,31 +209,35 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
// expert alphabetically, so dense increments the counter and expert uses counter-1.
var iLayer int
if strings.Contains(name, "_exps") {
if kv.Architecture() == "laguna" {
goto finalize
}
iLayer = max(0, qs.iFfnDown-1)
} else {
iLayer = qs.iFfnDown
qs.iFfnDown++
}
n_layer := qs.nFfnDown
if ftype == fsggml.FileTypeQ4_K_M {
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K
}
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ5_K
}
} else if strings.Contains(name, "attn_output.weight") {
if nExperts == 8 {
if newType != fsggml.TensorTypeQ8_0 && nExperts == 8 {
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
}
} else if strings.Contains(name, "attn_qkv.weight") {
if ftype == fsggml.FileTypeQ4_K_M {
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
}
finalize:
if newType.IsQuantized() {
nx := shape[0]
qk_k := newType.BlockSize()
@@ -218,7 +273,12 @@ func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType,
kv := maps.Clone(orig.KV())
kv["general.file_type"] = newFileType
// kv["general.quantization_version"] = ggml.QuantizationVersion()
qs := &quantizeState{}
qs := &quantizeState{
sourceFP8Tensors: sourceFP8TensorSet(kv),
}
hasSourceFP8 := hasSourceFP8Tensors(kv)
qs.preserveSourceFP8ToQ8 = hasSourceFP8 && newFileType == fsggml.FileTypeQ8_0
qs.preserveSourceQ4 = hasSourceFP8 && slices.Contains([]fsggml.FileType{fsggml.FileTypeQ4_K_M, fsggml.FileTypeQ4_K_S}, newFileType)
// Build up the quantize state so newType can adjust types
layerCount := 0
for k, l := range orig.Tensors().GroupLayers() {
@@ -304,13 +364,34 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
newType := fsggml.TensorType(t.Kind)
if quantize {
if qs.preserveSourceFP8ToQ8 {
if _, ok := qs.sourceFP8Tensors[name]; !ok {
return newType
}
}
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
if qt, ok := qwen3LinearAttnQuantType(name); ok {
return qt
}
}
// TODO: Consider extracting architecture-specific GGUF quantization policy
// from server so different quantization backends can share one source of
// truth for model-family specializations.
// get more optimal quantization type based on the tensor shape, layer, etc.
if qs.preserveSourceQ4 {
if _, ok := qs.sourceFP8Tensors[name]; !ok {
defaultType = fsggml.TensorTypeQ8_0
}
}
if kv.Architecture() == "laguna" {
var ok bool
defaultType, ok = lagunaGGUFQuantization(name, newType, defaultType, ftype, int(kv.Uint("block_count", 0)))
if !ok {
return newType
}
}
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
if newType != defaultType {
slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType)
@@ -318,3 +399,16 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
}
return newType
}
// sourceFP8TensorSet builds a lookup set from the "source_fp8_tensors"
// metadata list. It returns nil when the list is absent or empty.
func sourceFP8TensorSet(kv fsggml.KV) map[string]struct{} {
	tensorNames := kv.Strings("source_fp8_tensors")
	if len(tensorNames) == 0 {
		return nil
	}

	set := make(map[string]struct{}, len(tensorNames))
	for i := range tensorNames {
		set[tensorNames[i]] = struct{}{}
	}
	return set
}

View File

@@ -308,6 +308,95 @@ func TestQuantizeModel(t *testing.T) {
"output.weight": fsggml.TensorTypeQ8_0,
},
},
{
name: "source_fp8_q8_preserves_bf16_tensors",
kv: map[string]any{
"general.architecture": "test",
"source_quantization": "hf_fp8",
"source_fp8_tensors": []string{"blk.1.ffn_down_exps.weight"},
},
tensors: []*fsggml.Tensor{
{
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
},
newType: "Q8_0",
expectedTensorTypes: map[string]fsggml.TensorType{
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ8_0,
"blk.1.attn_q.weight": fsggml.TensorTypeBF16,
},
},
{
name: "source_fp8_q4_promotes_bf16_tensors_to_q8",
kv: map[string]any{
"general.architecture": "test",
"source_quantization": "hf_fp8",
"source_fp8_tensors": []string{
"blk.1.ffn_gate_exps.weight",
"blk.1.ffn_down_exps.weight",
},
},
tensors: []*fsggml.Tensor{
{
Name: "blk.1.ffn_gate_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.attn_v.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.ffn_down.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.attn_q_norm.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "blk.1.ffn_gate_inp.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
{
Name: "output.weight", Kind: uint32(fsggml.TensorTypeBF16),
Offset: uint64(0), Shape: []uint64{256, 1},
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
},
},
newType: "Q4_K_M",
expectedTensorTypes: map[string]fsggml.TensorType{
"blk.1.ffn_gate_exps.weight": fsggml.TensorTypeQ4_K,
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ6_K,
"blk.1.attn_q.weight": fsggml.TensorTypeQ8_0,
"blk.1.attn_v.weight": fsggml.TensorTypeQ8_0,
"blk.1.ffn_down.weight": fsggml.TensorTypeQ8_0,
"blk.1.attn_q_norm.weight": fsggml.TensorTypeBF16,
"blk.1.ffn_gate_inp.weight": fsggml.TensorTypeBF16,
"output.weight": fsggml.TensorTypeQ8_0,
},
},
{
name: "f32_short_data",
kv: map[string]any{

View File

@@ -618,8 +618,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if builtinParser != nil {
// only send messages with meaningful content (empty messages confuse clients)
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
// Emit chunks that carry logprobs even if the parser is still buffering
// visible content, otherwise generate logprobs disappear for models with
// builtin thinking/tool parsers.
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 || len(res.Logprobs) > 0 {
ch <- res
}

View File

@@ -102,6 +102,36 @@ func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.Resp
return w.ResponseRecorder
}
// readCreatedModelConfig loads the manifest for the named model, opens the
// config blob it references and decodes it as a model.ConfigV2. Any failure
// along the way aborts the calling test.
func readCreatedModelConfig(t *testing.T, name string) model.ConfigV2 {
	t.Helper()

	created, err := manifest.ParseNamedManifest(model.ParseName(name))
	if err != nil {
		t.Fatalf("parse manifest: %v", err)
	}

	digest := created.Config.Digest
	if digest == "" {
		t.Fatalf("unexpected empty config digest for manifest")
	}

	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		t.Fatalf("config blob path: %v", err)
	}

	blob, err := os.Open(blobPath)
	if err != nil {
		t.Fatalf("open config blob: %v", err)
	}
	defer blob.Close()

	decoded := model.ConfigV2{}
	if err := json.NewDecoder(blob).Decode(&decoded); err != nil {
		t.Fatalf("decode config: %v", err)
	}
	return decoded
}
func checkFileExists(t *testing.T, p string, expect []string) {
t.Helper()
@@ -981,6 +1011,134 @@ func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
}
}
// TestCreateLagunaDetectsRendererParser verifies that creating a model from a
// GGUF whose architecture is "laguna", without naming a renderer or parser in
// the request, stores "laguna" for both in the created model's config.
func TestCreateLagunaDetectsRendererParser(t *testing.T) {
	gin.SetMode(gin.TestMode)

	p := t.TempDir()
	t.Setenv("OLLAMA_MODELS", p)
	var s Server

	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":    "laguna",
		"general.parameter_count": uint64(33_400_000_000),
	}, nil)

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:   "test",
		Files:  map[string]string{"test.gguf": digest},
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	// Use the shared helper rather than re-implementing manifest parsing and
	// config-blob decoding inline, matching the Nemotron create tests.
	cfg := readCreatedModelConfig(t, "test")
	if cfg.Renderer != "laguna" {
		t.Fatalf("expected renderer %q, got %q", "laguna", cfg.Renderer)
	}
	if cfg.Parser != "laguna" {
		t.Fatalf("expected parser %q, got %q", "laguna", cfg.Parser)
	}
}
// TestCreateNemotronHDefaultsRendererParser checks, for each Nemotron-H
// architecture name, that a model created without an explicit renderer or
// parser ends up with "nemotron-3-nano" for both.
func TestCreateNemotronHDefaultsRendererParser(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const wantName = "nemotron-3-nano"
	architectures := []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}

	for i := range architectures {
		arch := architectures[i]
		t.Run(arch, func(t *testing.T) {
			// Each subtest gets its own empty model store.
			t.Setenv("OLLAMA_MODELS", t.TempDir())
			var s Server

			_, digest := createBinFile(t, ggml.KV{
				"general.architecture": arch,
			}, nil)

			modelName := strings.ReplaceAll(arch, "_", "-")
			resp := createRequest(t, s.CreateHandler, api.CreateRequest{
				Name:   modelName,
				Files:  map[string]string{"test.gguf": digest},
				Stream: &stream,
			})
			if resp.Code != http.StatusOK {
				t.Fatalf("expected status code 200, actual %d", resp.Code)
			}

			got := readCreatedModelConfig(t, modelName)
			if got.Renderer != wantName {
				t.Fatalf("expected renderer %q, got %q", wantName, got.Renderer)
			}
			if got.Parser != wantName {
				t.Fatalf("expected parser %q, got %q", wantName, got.Parser)
			}
		})
	}
}
// TestCreateNemotronHDefaultsKeepExplicitRendererParser checks, for each
// Nemotron-H architecture name, that a renderer and parser supplied in the
// create request are stored as given rather than replaced by the defaults.
func TestCreateNemotronHDefaultsKeepExplicitRendererParser(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		wantRenderer = "custom-renderer"
		wantParser   = "custom-parser"
	)
	architectures := []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}

	for i := range architectures {
		arch := architectures[i]
		t.Run(arch, func(t *testing.T) {
			// Each subtest gets its own empty model store.
			t.Setenv("OLLAMA_MODELS", t.TempDir())
			var s Server

			_, digest := createBinFile(t, ggml.KV{
				"general.architecture": arch,
			}, nil)

			modelName := strings.ReplaceAll(arch, "_", "-") + "-custom"
			resp := createRequest(t, s.CreateHandler, api.CreateRequest{
				Name:     modelName,
				Files:    map[string]string{"test.gguf": digest},
				Renderer: wantRenderer,
				Parser:   wantParser,
				Stream:   &stream,
			})
			if resp.Code != http.StatusOK {
				t.Fatalf("expected status code 200, actual %d", resp.Code)
			}

			got := readCreatedModelConfig(t, modelName)
			if got.Renderer != wantRenderer {
				t.Fatalf("expected renderer %q, got %q", wantRenderer, got.Renderer)
			}
			if got.Parser != wantParser {
				t.Fatalf("expected parser %q, got %q", wantParser, got.Parser)
			}
		})
	}
}
func TestDetectModelTypeFromFiles(t *testing.T) {
t.Run("gguf file", func(t *testing.T) {
_, digest := createBinFile(t, nil, nil)

View File

@@ -1473,6 +1473,119 @@ func TestGenerateLogprobs(t *testing.T) {
})
}
// TestGenerateLogprobsWithBuiltinParser is a regression test for generate
// requests on a model that has a builtin parser configured: every logprob
// entry produced by the runner must reach the final response. The mock runner
// emits three chunks with one logprob entry each, and the non-streaming
// response is expected to carry all three.
func TestGenerateLogprobsWithBuiltinParser(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{}
// Fake completion: three chunks, each with exactly one logprob entry. The
// second chunk contains a "</think>" marker in its content.
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
responses := []llm.CompletionResponse{
{
Content: "h",
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "h", Logprob: -0.1}}},
},
{
Content: "i</think>Hello",
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "hi", Logprob: -0.2}}},
},
{
Content: " world",
Done: true,
DoneReason: llm.DoneReasonStop,
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.3}}},
},
}
for _, resp := range responses {
// Stop emitting as soon as the request context is cancelled.
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
}
}
return nil
}
// Server with a scheduler wired to the mock runner; loadFn hands the mock
// back immediately instead of loading a real model.
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: &mock}
return false
},
},
}
// The scheduler loop is tied to the test's context.
go s.sched.Run(t.Context())
// Minimal llama-architecture GGUF used as the model file.
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Create the model with the "deepseek3" builtin parser so the generate path
// goes through its builtin-parser branch.
// NOTE(review): presumably this parser buffers the pre-"</think>" content as
// thinking — confirm against the parser implementation.
if w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-generate-logprob-parser",
Files: map[string]string{"file.gguf": digest},
Parser: "deepseek3",
Template: `{{ .Prompt }}`,
Stream: &stream,
}); w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Non-streaming request with logprobs enabled.
noStream := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-generate-logprob-parser",
Prompt: "Why is the sky blue?",
Stream: &noStream,
Logprobs: true,
Options: map[string]any{
"temperature": 0,
},
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var resp api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
// One logprob entry per mock chunk must survive to the final response.
if got := len(resp.Logprobs); got != 3 {
t.Fatalf("expected 3 logprob entries, got %d", got)
}
}
func TestChatLogprobs(t *testing.T) {
t.Run("invalid top_logprobs negative", func(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.De
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
modelparsers "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
@@ -132,11 +133,11 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = inferSafetensorsCapabilities(opts.ModelDir)
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -183,7 +184,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil
}
func inferSafetensorsCapabilities(modelDir string) []string {
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
capabilities := []string{"completion"}
// Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures.
@@ -195,7 +196,16 @@ func inferSafetensorsCapabilities(modelDir string) []string {
capabilities = append(capabilities, "audio")
}
if supportsThinking(modelDir) {
var builtinParser modelparsers.Parser
if parserName != "" {
builtinParser = modelparsers.ParserForName(parserName)
}
if builtinParser != nil && builtinParser.HasToolSupport() {
capabilities = append(capabilities, "tools")
}
if supportsThinking(modelDir) || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
capabilities = append(capabilities, "thinking")
}
@@ -453,8 +463,8 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
// supportsThinking checks if the model supports thinking mode based on known
// architectures that do not expose a cleaner signal in their local metadata.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
@@ -554,6 +564,9 @@ func getParserName(modelDir string) string {
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "laguna") {
return "laguna"
}
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
@@ -571,6 +584,9 @@ func getParserName(modelDir string) string {
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "laguna") {
return "laguna"
}
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
@@ -608,6 +624,9 @@ func getRendererName(modelDir string) string {
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "laguna") {
return "laguna"
}
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
@@ -625,6 +644,9 @@ func getRendererName(modelDir string) string {
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "laguna") {
return "laguna"
}
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}

View File

@@ -352,7 +352,7 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
t.Fatal(err)
}
if got := inferSafetensorsCapabilities(dir); !slices.Equal(got, tt.want) {
if got := inferSafetensorsCapabilities(dir, ""); !slices.Equal(got, tt.want) {
t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want)
}
})
@@ -554,6 +554,11 @@ func TestSupportsThinking(t *testing.T) {
configJSON: `{"model_type": "deepseek"}`,
want: true,
},
{
name: "laguna architecture without template",
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
want: false,
},
{
name: "empty config",
configJSON: `{}`,
@@ -584,6 +589,55 @@ func TestSupportsThinking_NoConfig(t *testing.T) {
}
}
// TestInferSafetensorsCapabilitiesFromParser checks the capabilities derived
// purely from the parser name: the model directory holds only an empty
// config.json, so nothing comes from the model config itself.
func TestInferSafetensorsCapabilitiesFromParser(t *testing.T) {
	cases := []struct {
		name       string
		parserName string
		want       []string
	}{
		{
			name:       "laguna tools and thinking",
			parserName: "laguna",
			want:       []string{"completion", "tools", "thinking"},
		},
		{
			name:       "functiongemma tools only",
			parserName: "functiongemma",
			want:       []string{"completion", "tools"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			configPath := filepath.Join(dir, "config.json")
			if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
				t.Fatal(err)
			}

			got := inferSafetensorsCapabilities(dir, tc.parserName)
			if !slices.Equal(got, tc.want) {
				t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tc.want)
			}
		})
	}
}
// TestInferSafetensorsCapabilitiesLaguna checks that a Laguna config combined
// with the "laguna" parser yields completion, tools and thinking, and yields
// neither vision nor audio.
func TestInferSafetensorsCapabilitiesLaguna(t *testing.T) {
	dir := t.TempDir()
	configJSON := []byte(`{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`)
	if err := os.WriteFile(filepath.Join(dir, "config.json"), configJSON, 0o644); err != nil {
		t.Fatal(err)
	}

	caps := inferSafetensorsCapabilities(dir, "laguna")

	required := []string{"completion", "tools", "thinking"}
	for _, want := range required {
		if !slices.Contains(caps, want) {
			t.Fatalf("capabilities %v missing %q", caps, want)
		}
	}

	for _, unwanted := range []string{"vision", "audio"} {
		if slices.Contains(caps, unwanted) {
			t.Fatalf("unexpected non-text capability in %v", caps)
		}
	}
}
func TestGetParserName(t *testing.T) {
tests := []struct {
name string
@@ -615,6 +669,11 @@ func TestGetParserName(t *testing.T) {
configJSON: `{"model_type": "qwen3"}`,
want: "qwen3",
},
{
name: "laguna model",
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
want: "laguna",
},
{
name: "no config",
configJSON: `{}`,
@@ -660,6 +719,11 @@ func TestGetRendererName(t *testing.T) {
configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
want: "",
},
{
name: "laguna model",
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
want: "laguna",
},
}
for _, tt := range tests {

View File

@@ -776,6 +776,11 @@ type tensorImportTransform interface {
quantizationType(name string, shape []int32, quantize string) string
}
// sourceFP8TensorImportTransform is an optional extension of
// tensorImportTransform. CreateSafetensorsModel type-asserts the import
// transform against it and, when satisfied, calls these methods in place of
// the package-level sourceFP8TensorQuantization and
// sourceFP8BF16PromotionQuantization helpers.
type sourceFP8TensorImportTransform interface {
// sourceFP8TensorQuantization returns the quantization type name for the
// output tensor given the user-requested quantization and the fallback type
// already computed by the caller.
// NOTE(review): lagunaImportTransform returns "" for some tensors; presumably
// "" means "leave unquantized" — confirm against the caller.
sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string
// sourceFP8BF16Quantization returns the quantization type name for tensors
// handled by the sourceQuantizedKindSourceFP8 branch of the caller.
sourceFP8BF16Quantization(name string, shape []int32, requested string) string
}
type noopImportTransform struct{}
func (noopImportTransform) skipTensor(string) bool { return false }
@@ -804,6 +809,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
"LagunaForCausalLM": newLagunaImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
@@ -842,6 +848,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if err != nil {
return fmt.Errorf("failed to construct import transform for architecture %q: %w", sourceConfig.Architecture(), err)
}
sourceFP8Transform, _ := importTransform.(sourceFP8TensorImportTransform)
// Resolve the optional packed layer creator
var packedCreator PackedTensorLayerCreator
@@ -991,9 +998,17 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// synthetic tests may not pass the generic import size filter.
quantizeType = "mxfp8"
}
quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
if sourceFP8Transform != nil {
quantizeType = sourceFP8Transform.sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
} else {
quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
}
case sourceQuantKind == sourceQuantizedKindSourceFP8:
quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize)
if sourceFP8Transform != nil {
quantizeType = sourceFP8Transform.sourceFP8BF16Quantization(outTD.Name, outTD.Shape, quantize)
} else {
quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize)
}
case effectiveQuantize != "":
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
}

59
x/create/laguna.go Normal file
View File

@@ -0,0 +1,59 @@
package create
import (
"strings"
"github.com/ollama/ollama/x/safetensors"
)
// lagunaImportTransform is the stateless import transform registered for the
// "LagunaForCausalLM" architecture. It passes tensors through unchanged and
// only influences which quantization each tensor receives: routed-expert
// weights (names containing ".mlp.experts." and ending in ".weight") are the
// only tensors it ever assigns a quantization type to.
type lagunaImportTransform struct{}
// newLagunaImportTransform is the factory used by the import-transform
// registry. The transform carries no state, so both arguments are ignored and
// the returned error is always nil.
func newLagunaImportTransform(_ string, _ sourceModelConfig) (tensorImportTransform, error) {
	var transform lagunaImportTransform
	return transform, nil
}
// skipTensor always returns false: the Laguna transform never drops a tensor,
// whatever its name.
func (lagunaImportTransform) skipTensor(_ string) bool {
	return false
}
// transformTensor passes td through untouched, wrapped in a one-element slice.
// A nil td yields a nil slice. The error result is always nil.
func (lagunaImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	if td != nil {
		return []*safetensors.TensorData{td}, nil
	}
	return nil, nil
}
// quantizationType selects the quantization for a tensor when the user asked
// for the given quantize type. Only routed-expert weights are quantized: for
// those the decision is delegated to GetTensorQuantization; every other tensor
// gets "" (no quantization type).
func (lagunaImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	if lagunaIsHFRoutedExpertWeight(name) {
		return GetTensorQuantization(name, shape, quantize)
	}
	return ""
}
// sourceFP8TensorQuantization chooses the quantization for a tensor on the
// source-FP8 import path.
//
//   - Tensors that are not routed-expert weights get "".
//   - For routed-expert weights, when the normalized requested type is "nvfp4"
//     or "mxfp4" and lagunaKeepSourceFP8TensorAtMXFP8 accepts the tensor
//     (a 2-D, mxfp8-aligned down_proj), the result is "mxfp8".
//   - Otherwise the caller-supplied fallback is returned unchanged.
func (lagunaImportTransform) sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string {
	if !lagunaIsHFRoutedExpertWeight(name) {
		return ""
	}

	kind := normalizeQuantType(requested)
	fourBitRequest := kind == "nvfp4" || kind == "mxfp4"
	if fourBitRequest && lagunaKeepSourceFP8TensorAtMXFP8(name, shape) {
		return "mxfp8"
	}
	return fallback
}
// sourceFP8BF16Quantization ignores all of its arguments and always returns
// "", so this transform never assigns a quantization type through this hook.
func (lagunaImportTransform) sourceFP8BF16Quantization(_ string, _ []int32, _ string) string {
	const noQuantization = ""
	return noQuantization
}
// lagunaKeepSourceFP8TensorAtMXFP8 reports whether a tensor should stay at
// mxfp8 instead of a 4-bit type. All three conditions must hold: the shape is
// two-dimensional, isAligned accepts it for "mxfp8", and the name contains
// "down_proj".
func lagunaKeepSourceFP8TensorAtMXFP8(name string, shape []int32) bool {
	return len(shape) == 2 &&
		isAligned(shape, "mxfp8") &&
		strings.Contains(name, "down_proj")
}
// lagunaIsHFRoutedExpertWeight reports whether name looks like a routed-expert
// weight tensor: it must end in ".weight" and contain ".mlp.experts.".
// Companion tensors such as "...weight_scale_inv" and shared-expert tensors
// (".mlp.shared_experts.") therefore do not match.
func lagunaIsHFRoutedExpertWeight(name string) bool {
	if !strings.HasSuffix(name, ".weight") {
		return false
	}
	return strings.Contains(name, ".mlp.experts.")
}

200
x/create/laguna_test.go Normal file
View File

@@ -0,0 +1,200 @@
package create
import (
"io"
"os"
"path/filepath"
"testing"
st "github.com/ollama/ollama/x/safetensors"
)
func TestCreateSafetensorsModel_LagunaHFFP8RespectsSourceTensorPrecision(t *testing.T) {
tests := []struct {
name string
requested string
wantFP8Gate string
wantFP8Up string
wantFP8Down string
wantBF16QProj string
}{
{
name: "default mxfp8 import keeps source bf16 tensors",
requested: "",
wantFP8Gate: "mxfp8",
wantFP8Up: "mxfp8",
wantFP8Down: "mxfp8",
wantBF16QProj: "",
},
{
name: "nvfp4 import keeps source bf16 tensors and preserves down_proj at mxfp8",
requested: "nvfp4",
wantFP8Gate: "nvfp4",
wantFP8Up: "nvfp4",
wantFP8Down: "mxfp8",
wantBF16QProj: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "laguna",
"architectures": ["LagunaForCausalLM"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
})
quantizeByName := make(map[string]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
quantizeByName[name] = quantize
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if got := quantizeByName["model.layers.0.mlp.experts.0.gate_proj.weight"]; got != tt.wantFP8Gate {
t.Fatalf("gate_proj quantization = %q, want %q", got, tt.wantFP8Gate)
}
if got := quantizeByName["model.layers.0.mlp.experts.0.up_proj.weight"]; got != tt.wantFP8Up {
t.Fatalf("up_proj quantization = %q, want %q", got, tt.wantFP8Up)
}
if got := quantizeByName["model.layers.0.mlp.experts.0.down_proj.weight"]; got != tt.wantFP8Down {
t.Fatalf("down_proj quantization = %q, want %q", got, tt.wantFP8Down)
}
for _, name := range []string{
"model.layers.0.self_attn.q_proj.weight",
"model.embed_tokens.weight",
"lm_head.weight",
"model.layers.0.mlp.gate.weight",
} {
if got := quantizeByName[name]; got != tt.wantBF16QProj {
t.Fatalf("%s quantization = %q, want %q", name, got, tt.wantBF16QProj)
}
}
})
}
}
func TestCreateSafetensorsModel_LagunaBF16QuantizesOnlyRoutedExperts(t *testing.T) {
tests := []struct {
name string
requested string
want map[string]string
}{
{
name: "int8 quantizes only routed experts",
requested: "int8",
want: map[string]string{
"model.layers.0.mlp.experts.0.gate_proj.weight": "int8",
"model.layers.0.mlp.experts.0.up_proj.weight": "int8",
"model.layers.0.mlp.experts.0.down_proj.weight": "int8",
"model.layers.0.mlp.shared_experts.gate_proj.weight": "",
"model.layers.0.mlp.shared_experts.down_proj.weight": "",
"model.layers.0.self_attn.q_proj.weight": "",
"model.layers.0.mlp.down_proj.weight": "",
"model.embed_tokens.weight": "",
"lm_head.weight": "",
"model.layers.0.mlp.gate.weight": "",
},
},
{
name: "int4 keeps routed down_proj at int8 and leaves others bf16",
requested: "int4",
want: map[string]string{
"model.layers.0.mlp.experts.0.gate_proj.weight": "int4",
"model.layers.0.mlp.experts.0.up_proj.weight": "int4",
"model.layers.0.mlp.experts.0.down_proj.weight": "int8",
"model.layers.0.mlp.shared_experts.gate_proj.weight": "",
"model.layers.0.mlp.shared_experts.down_proj.weight": "",
"model.layers.0.self_attn.q_proj.weight": "",
"model.layers.0.mlp.down_proj.weight": "",
"model.embed_tokens.weight": "",
"lm_head.weight": "",
"model.layers.0.mlp.gate.weight": "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "laguna",
"architectures": ["LagunaForCausalLM"]
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
})
quantizeByName := make(map[string]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
quantizeByName[name] = quantize
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
for name, want := range tt.want {
if got := quantizeByName[name]; got != want {
t.Fatalf("%s quantization = %q, want %q", name, got, want)
}
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/gemma4"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/laguna"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"

View File

@@ -34,6 +34,19 @@ var SiLU = Compile1(
Shapeless(),
)
// SoftplusF32 is a fused kernel that evaluates softplus(x) in float32 and then
// casts the result back to the dtype x arrived with. Softplus is expressed as
// Logaddexp(x, 0). This mirrors the laguna attention output-gate formula:
// softplus(cast_f32(x)).cast(orig_dtype).
var SoftplusF32 = Compile1(
	"SoftplusF32",
	func(x *Array) *Array {
		// Remember the incoming dtype so the result can be cast back to it.
		origDType := x.DType()
		zeroF32 := FromValue[float32](0)
		upcast := x.AsType(DTypeFloat32)
		return Logaddexp(upcast, zeroF32).AsType(origDType)
	},
	Shapeless(),
)
// SwiGLU returns silu(gate) * up as a fused kernel.
var SwiGLU = Compile2(
"SwiGLU",

1216
x/models/laguna/laguna.go Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,509 @@
package laguna
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// TestParseConfigLagunaXS parses a config that uses a flat "rope_parameters"
// object plus a separate "swa_rope_parameters" object, and checks the derived
// RoPE settings for full and sliding attention as well as the per-layer
// helpers (sliding detection, MoE vs dense, per-layer head count).
func TestParseConfigLagunaXS(t *testing.T) {
	skipIfNoMLX(t)

	const raw = `{
"model_type": "laguna",
"hidden_size": 2048,
"intermediate_size": 8192,
"moe_intermediate_size": 512,
"shared_expert_intermediate_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 48,
"num_attention_heads_per_layer": [48, 64, 64, 64],
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
"sliding_window": 512,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 256,
"num_experts_per_tok": 8,
"norm_topk_prob": true,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-6,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"rope_theta": 500000,
"rope_type": "yarn",
"factor": 32,
"original_max_position_embeddings": 4096,
"beta_fast": 64,
"beta_slow": 1,
"attention_factor": 1
},
"swa_rope_parameters": {
"partial_rotary_factor": 1.0,
"rope_theta": 10000,
"rope_type": "linear"
}
}`
	conf, err := parseConfig([]byte(raw))
	if err != nil {
		t.Fatal(err)
	}

	// Full-attention RoPE: head_dim 128 with partial_rotary_factor 0.5.
	if got := conf.FullRopeDim; got != 64 {
		t.Fatalf("FullRopeDim = %d, want 64", got)
	}
	if got := conf.FullRopeBase; got != 500000 {
		t.Fatalf("FullRopeBase = %v, want 500000", got)
	}
	if got := conf.FullRopeScale; got != 1 {
		t.Fatalf("FullRopeScale = %v, want explicit YaRN attention_factor", got)
	}
	if conf.FullRopeFreqs == nil {
		t.Fatal("FullRopeFreqs should be precomputed for YaRN")
	}

	// Sliding-attention RoPE comes from swa_rope_parameters.
	if got := conf.SlidingRopeDim; got != 128 {
		t.Fatalf("SlidingRopeDim = %d, want 128", got)
	}
	if got := conf.SlidingRopeBase; got != 10000 {
		t.Fatalf("SlidingRopeBase = %v, want 10000", got)
	}

	// Per-layer helpers.
	if sliding := layerIsSliding(&conf, 1); !sliding {
		t.Fatal("layer 1 should use sliding attention")
	}
	if moe := layerUsesMoE(&conf, 0); moe {
		t.Fatal("layer 0 should be dense due to mlp_only_layers")
	}
	if moe := layerUsesMoE(&conf, 1); !moe {
		t.Fatal("layer 1 should use MoE")
	}
	heads := numHeadsForLayer(&conf, 1)
	if heads != 64 {
		t.Fatalf("numHeadsForLayer(1) = %d, want 64", heads)
	}
}
// TestParseConfigLagunaFP8RopeScaling parses a config that carries a top-level
// "rope_theta" and a "rope_scaling" object (no "rope_parameters") and checks
// that the full-attention RoPE base and rotary dimension are still derived.
func TestParseConfigLagunaFP8RopeScaling(t *testing.T) {
	skipIfNoMLX(t)

	const raw = `{
"hidden_size": 2048,
"intermediate_size": 8192,
"num_hidden_layers": 1,
"num_attention_heads": 48,
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"rope_theta": 500000,
"partial_rotary_factor": 0.5,
"rope_scaling": {
"rope_type": "yarn",
"factor": 32
}
}`
	conf, err := parseConfig([]byte(raw))
	if err != nil {
		t.Fatal(err)
	}

	if got := conf.FullRopeBase; got != 500000 {
		t.Fatalf("FullRopeBase = %v, want 500000", got)
	}
	if got := conf.FullRopeDim; got != 64 {
		t.Fatalf("FullRopeDim = %d, want 64", got)
	}
}
// TestParseConfigLagunaGASchema parses a config where "rope_parameters" is
// keyed by attention type, "gating" is the boolean true, "mlp_layer_types"
// replaces "mlp_only_layers", and "norm_topk_prob" is omitted. It checks that
// gating resolves to "per-head", NormTopKProb ends up true, the per-attention
// RoPE values are picked up, and mlp_layer_types drives the dense/MoE split.
func TestParseConfigLagunaGASchema(t *testing.T) {
	skipIfNoMLX(t)

	const raw = `{
"model_type": "laguna",
"hidden_size": 2048,
"intermediate_size": 8192,
"moe_intermediate_size": 512,
"shared_expert_intermediate_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 48,
"num_attention_heads_per_layer": [48, 64, 64, 64],
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
"sliding_window": 512,
"mlp_layer_types": ["dense", "sparse", "sparse", "sparse"],
"num_experts": 256,
"num_experts_per_tok": 8,
"moe_routed_scaling_factor": 2.5,
"gating": true,
"rms_norm_eps": 1e-6,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"full_attention": {
"rope_theta": 500000,
"rope_type": "yarn",
"factor": 32,
"original_max_position_embeddings": 4096,
"beta_fast": 64,
"beta_slow": 1,
"attention_factor": 1,
"partial_rotary_factor": 0.5
},
"sliding_attention": {
"rope_theta": 10000,
"rope_type": "default",
"partial_rotary_factor": 1.0
}
}
}`
	conf, err := parseConfig([]byte(raw))
	if err != nil {
		t.Fatal(err)
	}

	if got := conf.Gating; got != "per-head" {
		t.Fatalf("Gating = %q, want per-head", got)
	}
	if enabled := conf.NormTopKProb; !enabled {
		t.Fatal("NormTopKProb should default true")
	}

	if got := conf.FullRopeBase; got != 500000 {
		t.Fatalf("FullRopeBase = %v, want 500000", got)
	}
	if got := conf.SlidingRopeBase; got != 10000 {
		t.Fatalf("SlidingRopeBase = %v, want 10000", got)
	}
	if got := conf.FullRopeDim; got != 64 {
		t.Fatalf("FullRopeDim = %d, want 64", got)
	}
	if got := conf.SlidingRopeDim; got != 128 {
		t.Fatalf("SlidingRopeDim = %d, want 128", got)
	}

	if moe := layerUsesMoE(&conf, 0); moe {
		t.Fatal("layer 0 should be dense due to mlp_layer_types")
	}
	if moe := layerUsesMoE(&conf, 1); !moe {
		t.Fatal("layer 1 should use MoE")
	}
}
// TestTinyLagunaLoadAndForward is an end-to-end smoke test on a two-layer toy
// model (one dense full-attention layer, one sliding-attention MoE layer). It
// loads synthetic weights, runs a three-token forward pass and unembedding,
// and checks the output shapes and that every logit is finite.
func TestTinyLagunaLoadAndForward(t *testing.T) {
	skipIfNoMLX(t)

	const raw = `{
"model_type": "laguna",
"hidden_size": 8,
"intermediate_size": 12,
"moe_intermediate_size": 4,
"shared_expert_intermediate_size": 4,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_attention_heads_per_layer": [2, 2],
"num_key_value_heads": 1,
"head_dim": 4,
"vocab_size": 16,
"max_position_embeddings": 64,
"layer_types": ["full_attention", "sliding_attention"],
"sliding_window": 2,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 2,
"num_experts_per_tok": 1,
"norm_topk_prob": false,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-5,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"rope_theta": 10000,
"rope_type": "yarn",
"factor": 2,
"original_max_position_embeddings": 16,
"beta_fast": 32,
"beta_slow": 1
},
"swa_rope_parameters": {
"partial_rotary_factor": 1.0,
"rope_theta": 10000,
"rope_type": "linear"
}
}`
	conf, err := parseConfig([]byte(raw))
	if err != nil {
		t.Fatal(err)
	}

	model := &Model{
		Config: &conf,
		Layers: []*Layer{
			{LayerIdx: 0, IsSliding: false},
			{LayerIdx: 1, IsSliding: true},
		},
	}
	if err := model.LoadWeights(tinyLagunaTensors()); err != nil {
		t.Fatalf("LoadWeights failed: %v", err)
	}

	input := mlx.FromValues([]int32{1, 2, 3}, 1, 3)
	caches := model.NewCaches()
	// Release every allocated cache when the test finishes.
	releaseCaches := func() {
		for _, cache := range caches {
			if cache == nil {
				continue
			}
			cache.Free()
		}
	}
	defer releaseCaches()

	hidden := model.Forward(&batch.Batch{
		InputIDs:     input,
		SeqOffsets:   []int32{0},
		SeqQueryLens: []int32{int32(input.Dim(1))},
	}, caches)
	mlx.Eval(hidden)

	hiddenDims := hidden.Dims()
	hiddenOK := len(hiddenDims) == 3 && hiddenDims[0] == 1 && hiddenDims[1] == 3 && hiddenDims[2] == 8
	if !hiddenOK {
		t.Fatalf("hidden shape = %v, want [1 3 8]", hiddenDims)
	}

	logits := model.Unembed(hidden)
	mlx.Eval(logits)

	logitDims := logits.Dims()
	logitsOK := len(logitDims) == 3 && logitDims[0] == 1 && logitDims[1] == 3 && logitDims[2] == 16
	if !logitsOK {
		t.Fatalf("logits shape = %v, want [1 3 16]", logitDims)
	}

	for i, v := range logits.Floats() {
		f := float64(v)
		if math.IsNaN(f) || math.IsInf(f, 0) {
			t.Fatalf("logits[%d] is not finite: %v", i, v)
		}
	}
}
// TestTinyLagunaLoadWeightsFusesDenseGateUp loads the toy weights and checks
// that the MoE layer's SwitchMLP ends up with gate and up expert weights fused
// into a single GateUpWeight of shape [2 8 8].
func TestTinyLagunaLoadWeightsFusesDenseGateUp(t *testing.T) {
	skipIfNoMLX(t)

	const raw = `{
"model_type": "laguna",
"hidden_size": 8,
"intermediate_size": 12,
"moe_intermediate_size": 4,
"shared_expert_intermediate_size": 4,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_attention_heads_per_layer": [2, 2],
"num_key_value_heads": 1,
"head_dim": 4,
"vocab_size": 16,
"max_position_embeddings": 64,
"layer_types": ["full_attention", "sliding_attention"],
"sliding_window": 2,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 2,
"num_experts_per_tok": 1,
"norm_topk_prob": false,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-5
}`
	conf, err := parseConfig([]byte(raw))
	if err != nil {
		t.Fatal(err)
	}

	model := &Model{
		Config: &conf,
		Layers: []*Layer{
			{LayerIdx: 0, IsSliding: false},
			{LayerIdx: 1, IsSliding: true},
		},
	}
	if err := model.LoadWeights(tinyLagunaTensors()); err != nil {
		t.Fatalf("LoadWeights failed: %v", err)
	}

	// Layer 1 is the only MoE layer (layer 0 is listed in mlp_only_layers).
	sparse, isSparse := model.Layers[1].MLP.(*SparseMoE)
	if !isSparse {
		t.Fatalf("layer 1 MLP type = %T, want *SparseMoE", model.Layers[1].MLP)
	}
	if fused := sparse.SwitchMLP.UseFusedGateUp; !fused {
		t.Fatal("expected dense SwitchMLP to fuse gate/up expert weights")
	}
	if sparse.SwitchMLP.GateUpWeight == nil {
		t.Fatal("expected fused GateUpWeight to be populated")
	}

	got, want := sparse.SwitchMLP.GateUpWeight.Dims(), []int{2, 8, 8}
	match := len(got) == len(want)
	for i := 0; match && i < len(want); i++ {
		match = got[i] == want[i]
	}
	if !match {
		t.Fatalf("GateUpWeight dims = %v, want %v", got, want)
	}
}
// TestSparseMoERouteBiasAffectsSelectionNotRoutingWeights builds a two-expert
// router whose unbiased sigmoid scores favor expert 1, adds a correction bias
// large enough to flip the choice to expert 0, and checks that route selects
// expert 0 while the returned weight equals expert 0's unbiased sigmoid score.
func TestSparseMoERouteBiasAffectsSelectionNotRoutingWeights(t *testing.T) {
	skipIfNoMLX(t)

	cfg := &Config{
		HiddenSize:       1,
		NumExperts:       2,
		NumExpertsPerTok: 1,
		NormTopKProb:     false,
	}
	gateWeight := mlx.FromValues([]float32{-4, -3}, 2, 1).AsType(mlx.DTypeBFloat16)
	router := &SparseMoE{
		Gate:                 nn.NewLinear(gateWeight, nil),
		EScoreCorrectionBias: mlx.FromValues([]float32{0.5, 0}, 2),
	}

	input := mlx.FromValues([]float32{1}, 1, int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)
	routeWeights, routeIdx := router.route(input, cfg)
	routeWeights = routeWeights.AsType(mlx.DTypeFloat32)
	routeIdx = routeIdx.AsType(mlx.DTypeInt32)
	mlx.Eval(routeWeights, routeIdx)

	// Recompute the unbiased sigmoid scores directly from the gate.
	logits := router.Gate.Forward(input).AsType(mlx.DTypeFloat32)
	unbiased := mlx.Sigmoid(logits)
	mlx.Eval(unbiased)
	p := unbiased.Floats()

	// Sanity-check the fixture before checking the routing result.
	if p[0] >= p[1] {
		t.Fatalf("expected unbiased sigmoid scores to prefer expert 1, got %v", p)
	}
	if p[0]+0.5 <= p[1] {
		t.Fatalf("expected bias to flip selection to expert 0, got probs=%v", p)
	}

	picked := routeIdx.Ints()
	if len(picked) != 1 || picked[0] != 0 {
		t.Fatalf("selected experts = %v, want [0]", picked)
	}
	weightVals := routeWeights.Floats()
	if len(weightVals) != 1 || math.Abs(float64(weightVals[0]-p[0])) > 1e-6 {
		t.Fatalf("routing weights = %v, want [%v] using unbiased sigmoid scores", weightVals, p[0])
	}
}
// TestSwitchMLPFusedGateUpMatchesSeparate runs the same inputs and expert
// indices through a SwitchMLP with separate gate/up weights and through one
// using the fused GateUpWeight built by fuseExpertStacks, and checks that the
// two outputs agree within 1e-5.
func TestSwitchMLPFusedGateUpMatchesSeparate(t *testing.T) {
	skipIfNoMLX(t)

	cfg := &Config{HiddenSize: 4, NumExpertsPerTok: 2}
	B, L := int32(2), int32(3)

	// Deterministic input pattern, cast to bf16.
	inputVals := make([]float32, int(B*L*cfg.HiddenSize))
	for i := 0; i < len(inputVals); i++ {
		inputVals[i] = float32((i%17)-8) * 0.01
	}
	x := mlx.FromValues(inputVals, int(B), int(L), int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)

	// Each token picks both experts, alternating which one comes first.
	perTok := int(cfg.NumExpertsPerTok)
	picks := make([]int32, B*L*cfg.NumExpertsPerTok)
	for tok := 0; tok*perTok < len(picks); tok++ {
		base := tok * perTok
		picks[base] = int32(tok % 2)
		picks[base+1] = int32((tok + 1) % 2)
	}
	indices := mlx.FromValues(picks, int(B*L), perTok)

	separate := &SwitchMLP{
		GateWeight: makePatternExpertWeight(2, 4, 3, 0.011),
		UpWeight:   makePatternExpertWeight(2, 4, 3, 0.017),
		DownWeight: makePatternExpertWeight(2, 3, 4, 0.013),
	}
	fused := &SwitchMLP{
		GateUpWeight:   fuseExpertStacks(separate.GateWeight, separate.UpWeight, 2),
		DownWeight:     separate.DownWeight,
		UseFusedGateUp: true,
	}

	outSeparate := separate.Forward(x, indices, cfg)
	outFused := fused.Forward(x, indices, cfg)
	mlx.Eval(outSeparate, outFused)

	// Compare in float32.
	fusedF32 := outFused.AsType(mlx.DTypeFloat32)
	separateF32 := outSeparate.AsType(mlx.DTypeFloat32)
	mlx.Eval(fusedF32, separateF32)

	assertFloatSlicesClose(t, fusedF32.Floats(), separateF32.Floats(), 1e-5)
}
// TestCombinedTensorGlobalScaleIgnoresInputGlobalScale checks that, when both
// "<name>.global_scale" and "<name>.input_global_scale" are present,
// combinedTensorGlobalScale returns the global_scale value (0.25) alone. The
// function's second return value is not inspected here.
func TestCombinedTensorGlobalScaleIgnoresInputGlobalScale(t *testing.T) {
	skipIfNoMLX(t)

	inputs := map[string]*mlx.Array{
		"proj.weight.global_scale":       mlx.FromValues([]float32{0.25}, 1),
		"proj.weight.input_global_scale": mlx.FromValues([]float32{8}, 1),
	}

	scale, _ := combinedTensorGlobalScale(inputs, "proj.weight")
	if scale == nil {
		t.Fatal("combinedTensorGlobalScale returned nil")
	}
	mlx.Eval(scale)

	if vals := scale.Floats(); len(vals) != 1 || vals[0] != 0.25 {
		t.Fatalf("combinedTensorGlobalScale = %v, want [0.25]", vals)
	}
}
// tinyLagunaTensors builds the synthetic weight map for the two-layer toy
// model used by the tests: embeddings, final norm and lm_head; attention
// weights and norms for both layers; a dense MLP for layer 0; and for layer 1
// a router, a score-correction bias, two routed experts and a shared expert.
func tinyLagunaTensors() map[string]*mlx.Array {
	out := map[string]*mlx.Array{
		"model.embed_tokens.weight": weights(16, 8),
		"model.norm.weight":         ones(8),
		"lm_head.weight":            weights(16, 8),
	}

	// Per-layer norms and attention projections. The index is a single digit,
	// so it is rendered by offsetting from the rune '0'.
	for layer := range 2 {
		base := "model.layers." + string(rune('0'+layer)) + "."
		attn := base + "self_attn."
		out[base+"input_layernorm.weight"] = ones(8)
		out[base+"post_attention_layernorm.weight"] = ones(8)
		out[attn+"q_proj.weight"] = weights(8, 8)
		out[attn+"k_proj.weight"] = weights(4, 8)
		out[attn+"v_proj.weight"] = weights(4, 8)
		out[attn+"o_proj.weight"] = weights(8, 8)
		out[attn+"g_proj.weight"] = weights(2, 8)
		out[attn+"q_norm.weight"] = ones(4)
		out[attn+"k_norm.weight"] = ones(4)
	}

	// Layer 0: dense MLP.
	const dense = "model.layers.0.mlp."
	out[dense+"gate_proj.weight"] = weights(12, 8)
	out[dense+"up_proj.weight"] = weights(12, 8)
	out[dense+"down_proj.weight"] = weights(8, 12)

	// Layer 1: router, correction bias, routed experts, shared expert.
	const sparse = "model.layers.1.mlp."
	out[sparse+"gate.weight"] = weights(2, 8)
	out[sparse+"experts.e_score_correction_bias"] = mlx.FromValues([]float32{0.1, -0.1}, 2)
	for expert := range 2 {
		base := sparse + "experts." + string(rune('0'+expert)) + "."
		out[base+"gate_proj.weight"] = weights(4, 8)
		out[base+"up_proj.weight"] = weights(4, 8)
		out[base+"down_proj.weight"] = weights(8, 4)
	}
	out[sparse+"shared_expert.gate_proj.weight"] = weights(4, 8)
	out[sparse+"shared_expert.up_proj.weight"] = weights(4, 8)
	out[sparse+"shared_expert.down_proj.weight"] = weights(8, 4)

	return out
}
// makeExpertWeight builds an array from vals with the given dims and casts it
// to bfloat16.
func makeExpertWeight(vals []float32, dims ...int) *mlx.Array {
	arr := mlx.FromValues(vals, dims...)
	return arr.AsType(mlx.DTypeBFloat16)
}
// makePatternExpertWeight returns a bfloat16 array with dims
// [numExperts, rows, cols] filled with a deterministic pattern: element i is
// ((i mod 23) - 11) * scale.
func makePatternExpertWeight(numExperts, rows, cols int, scale float32) *mlx.Array {
	total := numExperts * rows * cols
	pattern := make([]float32, total)
	for i := 0; i < total; i++ {
		pattern[i] = float32((i%23)-11) * scale
	}
	return makeExpertWeight(pattern, numExperts, rows, cols)
}
// assertFloatSlicesClose fails the test immediately unless got and want have
// the same length and every pair of elements differs by at most tol.
//
// Elements that compare exactly equal (including equal infinities) always
// pass. A NaN on either side is reported as a mismatch: a plain
// "math.Abs(diff) > tol" check is false for NaN, so it would let NaN outputs
// slip through as "close".
func assertFloatSlicesClose(t *testing.T, got, want []float32, tol float64) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
	}
	for i := range got {
		if got[i] == want[i] {
			// Exact match; also covers +Inf/+Inf and -Inf/-Inf, whose
			// difference would otherwise be NaN.
			continue
		}
		diff := math.Abs(float64(got[i] - want[i]))
		if math.IsNaN(diff) || diff > tol {
			t.Fatalf("value[%d] = %v, want %v (tol=%g)", i, got[i], want[i], tol)
		}
	}
}
// weights returns a rows x cols array filled with a small deterministic
// pattern: element i is ((i mod 7) - 3) * 0.01. The values are built as
// float32 and no dtype cast is applied here.
func weights(rows, cols int) *mlx.Array {
	count := rows * cols
	pattern := make([]float32, count)
	for i := 0; i < count; i++ {
		pattern[i] = float32((i%7)-3) * 0.01
	}
	return mlx.FromValues(pattern, rows, cols)
}
// ones returns a one-dimensional array of n elements, all equal to 1.
func ones(n int) *mlx.Array {
	filled := make([]float32, n)
	for i := 0; i < n; i++ {
		filled[i] = 1
	}
	return mlx.FromValues(filled, n)
}
// skipIfNoMLX skips the calling test when mlx.CheckInit reports an error,
// including that error in the skip message.
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}

View File

@@ -447,7 +447,9 @@ func extractPretokenizer(data json.RawMessage) string {
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
if _, err := regexp.Compile(rewritePatternForRE2(pt.Pattern.Regex)); err == nil {
return pt.Pattern.Regex
}
}
}
}

View File

@@ -22,3 +22,31 @@ func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
// TestExtractPretokenizerSkipsUnsupportedSequenceSplit feeds
// extractPretokenizer a Sequence whose first Split pattern is a newline
// splitter using the negative lookahead (?!\r?\n), followed by a second Split
// pattern. It checks that some pattern is returned and that it is not the
// newline splitter.
func TestExtractPretokenizerSkipsUnsupportedSequenceSplit(t *testing.T) {
	raw := []byte(`{
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?:\\r?\\n)+(?!\\r?\\n)"
}
},
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
}
}
]
}`)

	selected := extractPretokenizer(raw)
	switch {
	case selected == "":
		t.Fatal("expected supported Split pretokenizer")
	case strings.Contains(selected, `(?!\r?\n)`):
		t.Fatalf("selected unsupported newline splitter: %q", selected)
	}
}