mirror of
https://github.com/ollama/ollama.git
synced 2026-05-19 14:18:20 -04:00
New models (#15861)
* mlx: add laguna model support * convert: support fp8 safetensors import Decode HF F8_E4M3 safetensors with block scale companions into GGUF-supported tensor types, and record which output tensors came from FP8 source weights. Use that source-precision metadata during create quantization: default FP8-sourced GGUFs to Q8_0, keep non-FP8 tensors at their original precision for Q8_0, and promote non-FP8 quantizable tensors to Q8_0 for Q4_K requests. * ggml: add laguna model support * server: preserve generate logprobs with builtin parsers Generate requests were dropping logprob-only chunks whenever a builtin parser buffered visible content. Chat already handled this case, but generate only forwarded chunks with visible response, thinking, or tool-call output. Keep generate chunks that carry logprobs even when the builtin parser has not flushed visible content yet, and add a regression test that exercises the behavior with a generic thinking parser. * review comments - perf improvements * ggml: implement nemotron 3 nano omni * add poolside integration * update poolside doc * adapt to new cache setup * fix test * fix test --------- Co-authored-by: Eva Ho <hoyyeva@gmail.com>
This commit is contained in:
@@ -75,8 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
|
||||
|
||||
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes", "pool"}
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
r, ok := integrations[name]
|
||||
@@ -1494,6 +1493,11 @@ func TestIntegration_InstallHint(t *testing.T) {
|
||||
input: "openclaw",
|
||||
wantURL: "https://docs.openclaw.ai",
|
||||
},
|
||||
{
|
||||
name: "pool has hint",
|
||||
input: "pool",
|
||||
wantURL: "https://github.com/poolsideai/pool",
|
||||
},
|
||||
{
|
||||
name: "unknown has no hint",
|
||||
input: "unknown",
|
||||
@@ -1549,7 +1553,19 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
for _, info := range infos {
|
||||
got = append(got, info.Name)
|
||||
}
|
||||
if diff := compareStrings(got, integrationOrder); diff != "" {
|
||||
|
||||
want := append([]string(nil), integrationOrder...)
|
||||
if poolsideGOOS == "windows" {
|
||||
filtered := make([]string, 0, len(want))
|
||||
for _, name := range want {
|
||||
if name != "pool" {
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
}
|
||||
want = filtered
|
||||
}
|
||||
|
||||
if diff := compareStrings(got, want); diff != "" {
|
||||
t.Fatalf("launcher integration order mismatch: %s", diff)
|
||||
}
|
||||
})
|
||||
@@ -1567,6 +1583,9 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
|
||||
t.Run("includes known integrations", func(t *testing.T) {
|
||||
known := map[string]bool{"claude": false, "codex": false, "opencode": false}
|
||||
if poolsideGOOS != "windows" {
|
||||
known["pool"] = false
|
||||
}
|
||||
for _, info := range infos {
|
||||
if _, ok := known[info.Name]; ok {
|
||||
known[info.Name] = true
|
||||
@@ -1601,6 +1620,17 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestListIntegrationInfos_HidesPoolsideOnWindows(t *testing.T) {
|
||||
prev := poolsideGOOS
|
||||
poolsideGOOS = "windows"
|
||||
t.Cleanup(func() { poolsideGOOS = prev })
|
||||
|
||||
for _, info := range ListIntegrationInfos() {
|
||||
if info.Name == "pool" {
|
||||
t.Fatal("expected pool to be hidden on Windows")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
t.Run("installed recommended has base description", func(t *testing.T) {
|
||||
@@ -1707,6 +1737,20 @@ func TestIntegration_AutoInstallable(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureIntegrationInstalled_PoolsideUnsupportedOnWindows(t *testing.T) {
|
||||
prev := poolsideGOOS
|
||||
poolsideGOOS = "windows"
|
||||
t.Cleanup(func() { poolsideGOOS = prev })
|
||||
|
||||
err := EnsureIntegrationInstalled("pool", &Poolside{})
|
||||
if err == nil {
|
||||
t.Fatal("expected Windows unsupported error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not currently supported on Windows") {
|
||||
t.Fatalf("expected Windows warning, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -213,6 +213,7 @@ Supported integrations:
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
pool Poolside
|
||||
vscode VS Code (aliases: code)
|
||||
|
||||
Examples:
|
||||
|
||||
51
cmd/launch/poolside.go
Normal file
51
cmd/launch/poolside.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Poolside implements Runner for Poolside's CLI.
|
||||
type Poolside struct{}
|
||||
|
||||
var poolsideGOOS = runtime.GOOS
|
||||
|
||||
func (p *Poolside) String() string { return "Poolside" }
|
||||
|
||||
func poolsideUnsupportedError() error {
|
||||
return fmt.Errorf("Warning: Poolside is not currently supported on Windows")
|
||||
}
|
||||
|
||||
func (p *Poolside) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (p *Poolside) Run(model string, args []string) error {
|
||||
if poolsideGOOS == "windows" {
|
||||
return poolsideUnsupportedError()
|
||||
}
|
||||
|
||||
bin, err := exec.LookPath("pool")
|
||||
if err != nil {
|
||||
return fmt.Errorf("pool is not installed")
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, p.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"POOLSIDE_STANDALONE_BASE_URL="+envconfig.Host().String()+"/v1",
|
||||
"POOLSIDE_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
88
cmd/launch/poolside_test.go
Normal file
88
cmd/launch/poolside_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPoolsideArgs(t *testing.T) {
|
||||
p := &Poolside{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
extra []string
|
||||
want []string
|
||||
}{
|
||||
{name: "with model", model: "qwen3.5", want: []string{"-m", "qwen3.5"}},
|
||||
{name: "without model", extra: []string{"session"}, want: []string{"session"}},
|
||||
{name: "with model and extra args", model: "llama3.2", extra: []string{"--foo", "bar"}, want: []string{"-m", "llama3.2", "--foo", "bar"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := p.args(tt.model, tt.extra)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Fatalf("args(%q, %v) = %v, want %v", tt.model, tt.extra, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolsideRunSetsOllamaEnv(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "pool.log")
|
||||
poolPath := filepath.Join(tmpDir, "pool")
|
||||
script := "#!/bin/sh\n" +
|
||||
"printf 'base=%s\\nkey=%s\\nargs=%s\\n' \"$POOLSIDE_STANDALONE_BASE_URL\" \"$POOLSIDE_API_KEY\" \"$*\" > \"" + logPath + "\"\n"
|
||||
if err := os.WriteFile(poolPath, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake pool binary: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PATH", tmpDir)
|
||||
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
|
||||
|
||||
p := &Poolside{}
|
||||
if err := p.Run("qwen3.5", []string{"session"}); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read pool log: %v", err)
|
||||
}
|
||||
|
||||
got := string(data)
|
||||
if !strings.Contains(got, "base=http://127.0.0.1:11434/v1") {
|
||||
t.Fatalf("expected Poolside base URL override in log, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "key=ollama") {
|
||||
t.Fatalf("expected Poolside API key override in log, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "args=-m qwen3.5 session") {
|
||||
t.Fatalf("expected model and extra args in log, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolsideRunWindowsUnsupported(t *testing.T) {
|
||||
prev := poolsideGOOS
|
||||
poolsideGOOS = "windows"
|
||||
t.Cleanup(func() { poolsideGOOS = prev })
|
||||
|
||||
p := &Poolside{}
|
||||
err := p.Run("kimi-k2.6:cloud", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected Windows unsupported error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not currently supported on Windows") {
|
||||
t.Fatalf("expected Windows warning, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
|
||||
Description string
|
||||
}
|
||||
|
||||
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
|
||||
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi", "pool"}
|
||||
|
||||
var integrationSpecs = []*IntegrationSpec{
|
||||
{
|
||||
@@ -166,6 +166,18 @@ var integrationSpecs = []*IntegrationSpec{
|
||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "pool",
|
||||
Runner: &Poolside{},
|
||||
Description: "Poolside's software agent for enterprise development",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("pool")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://github.com/poolsideai/pool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hermes",
|
||||
Runner: &Hermes{},
|
||||
@@ -285,6 +297,9 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
if spec.Hidden {
|
||||
continue
|
||||
}
|
||||
if spec.Name == "pool" && poolsideGOOS == "windows" {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, *spec)
|
||||
}
|
||||
|
||||
@@ -399,6 +414,10 @@ func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
|
||||
if integration.spec.Name == "pool" && poolsideGOOS == "windows" {
|
||||
return poolsideUnsupportedError()
|
||||
}
|
||||
|
||||
if integration.installed {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pool",
|
||||
binary: "pool",
|
||||
runner: &Poolside{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".poolside", "config")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "kimi",
|
||||
binary: "kimi",
|
||||
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == "pool" && poolsideGOOS == "windows" {
|
||||
t.Skip("Poolside is intentionally unsupported on Windows")
|
||||
}
|
||||
|
||||
home := t.TempDir()
|
||||
setTestHome(t, home)
|
||||
|
||||
|
||||
@@ -316,6 +316,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
case "LagunaForCausalLM":
|
||||
conv = &lagunaModel{}
|
||||
case "GlmOcrForConditionalGeneration":
|
||||
conv = &glmOcrModel{}
|
||||
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
|
||||
@@ -324,6 +326,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &lfm2VLTextModel{}
|
||||
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
|
||||
conv = &qwen3NextModel{}
|
||||
case "NemotronH_Nano_VL_V2", "NemotronH_Nano_Omni_Reasoning_V3":
|
||||
conv = &nemotronHNanoVLModel{}
|
||||
case "NemotronHForCausalLM":
|
||||
conv = &nemotronHModel{}
|
||||
default:
|
||||
@@ -387,6 +391,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
for k, v := range sourceTensorKV(ts) {
|
||||
kv[k] = v
|
||||
}
|
||||
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
604
convert/convert_laguna.go
Normal file
604
convert/convert_laguna.go
Normal file
@@ -0,0 +1,604 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
iofs "io/fs"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type lagunaModel struct {
|
||||
ModelParameters
|
||||
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
Gating lagunaGatingMode `json:"gating"`
|
||||
QKNormType string `json:"qk_norm_type"`
|
||||
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"`
|
||||
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermediateSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"`
|
||||
MoERouterUseSigmoid bool `json:"moe_router_use_sigmoid"`
|
||||
MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"`
|
||||
DecoderSparseStep uint32 `json:"decoder_sparse_step"`
|
||||
MLPOnlyLayers []uint32 `json:"mlp_only_layers"`
|
||||
MLPLayerTypes []string `json:"mlp_layer_types"`
|
||||
|
||||
RopeParameters lagunaRopeParameters `json:"rope_parameters"`
|
||||
SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"`
|
||||
|
||||
SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"`
|
||||
}
|
||||
|
||||
type lagunaGatingMode string
|
||||
|
||||
type lagunaRopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeType string `json:"rope_type"`
|
||||
Type string `json:"type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
AttentionFactor float32 `json:"attention_factor"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
}
|
||||
|
||||
type lagunaRopeConfig struct {
|
||||
flat lagunaRopeParameters
|
||||
full lagunaRopeParameters
|
||||
sliding lagunaRopeParameters
|
||||
nested bool
|
||||
}
|
||||
|
||||
func (g *lagunaGatingMode) UnmarshalJSON(b []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err == nil {
|
||||
*g = lagunaGatingMode(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
var enabled bool
|
||||
if err := json.Unmarshal(b, &enabled); err == nil {
|
||||
if enabled {
|
||||
*g = "true"
|
||||
} else {
|
||||
*g = "false"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if string(b) == "null" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unsupported Laguna gating JSON value %s", string(b))
|
||||
}
|
||||
|
||||
func (g lagunaGatingMode) perHead() bool {
|
||||
return strings.EqualFold(string(g), "per-head") || strings.EqualFold(string(g), "true")
|
||||
}
|
||||
|
||||
func (r *lagunaRopeConfig) UnmarshalJSON(b []byte) error {
|
||||
if string(b) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var probe map[string]json.RawMessage
|
||||
if err := json.Unmarshal(b, &probe); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(probe) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if raw, ok := probe["full_attention"]; ok {
|
||||
r.nested = true
|
||||
if err := json.Unmarshal(raw, &r.full); err != nil {
|
||||
return err
|
||||
}
|
||||
if raw = probe["sliding_attention"]; raw != nil {
|
||||
if err := json.Unmarshal(raw, &r.sliding); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if raw, ok := probe["global_attention"]; ok {
|
||||
r.nested = true
|
||||
if err := json.Unmarshal(raw, &r.full); err != nil {
|
||||
return err
|
||||
}
|
||||
if raw = probe["sliding_attention"]; raw != nil {
|
||||
if err := json.Unmarshal(raw, &r.sliding); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, &r.flat)
|
||||
}
|
||||
|
||||
func (r lagunaRopeConfig) fullParams() lagunaRopeParameters {
|
||||
if r.nested {
|
||||
return r.full
|
||||
}
|
||||
return r.flat
|
||||
}
|
||||
|
||||
func (r lagunaRopeConfig) slidingParams() (lagunaRopeParameters, bool) {
|
||||
if !r.nested {
|
||||
return lagunaRopeParameters{}, false
|
||||
}
|
||||
return r.sliding, true
|
||||
}
|
||||
|
||||
func (r lagunaRopeParameters) ropeType() string {
|
||||
return cmp.Or(r.RopeType, r.Type)
|
||||
}
|
||||
|
||||
func (r lagunaRopeParameters) withDefaultPartialRotaryFactor(v float32) lagunaRopeParameters {
|
||||
if r.PartialRotaryFactor == 0 {
|
||||
r.PartialRotaryFactor = v
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r lagunaRopeParameters) empty() bool {
|
||||
return r == (lagunaRopeParameters{})
|
||||
}
|
||||
|
||||
type rawLagunaModel struct {
|
||||
ModelParameters
|
||||
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
Gating lagunaGatingMode `json:"gating"`
|
||||
QKNormType string `json:"qk_norm_type"`
|
||||
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
NumAttentionHeadsPerLayer []uint32 `json:"num_attention_heads_per_layer"`
|
||||
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermediateSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
NormTopKProb *bool `json:"norm_topk_prob"`
|
||||
MoeRoutedScalingFactor float32 `json:"moe_routed_scaling_factor"`
|
||||
MoERouterUseSigmoid *bool `json:"moe_router_use_sigmoid"`
|
||||
MoEApplyRouterWeightOnInput bool `json:"moe_apply_router_weight_on_input"`
|
||||
DecoderSparseStep uint32 `json:"decoder_sparse_step"`
|
||||
MLPOnlyLayers []uint32 `json:"mlp_only_layers"`
|
||||
MLPLayerTypes []string `json:"mlp_layer_types"`
|
||||
|
||||
RopeParameters lagunaRopeConfig `json:"rope_parameters"`
|
||||
SwaRopeParameters lagunaRopeParameters `json:"swa_rope_parameters"`
|
||||
|
||||
SwaAttentionSinkEnabled bool `json:"swa_attention_sink_enabled"`
|
||||
}
|
||||
|
||||
func (p *lagunaModel) UnmarshalJSON(b []byte) error {
|
||||
var raw rawLagunaModel
|
||||
if err := json.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mlpOnlyLayers, err := lagunaDenseLayers(raw.MLPOnlyLayers, raw.MLPLayerTypes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fullRope := raw.RopeParameters.fullParams().withDefaultPartialRotaryFactor(cmp.Or(raw.PartialRotaryFactor, float32(1)))
|
||||
swaRope := raw.SwaRopeParameters
|
||||
if nestedSwa, ok := raw.RopeParameters.slidingParams(); ok && !nestedSwa.empty() {
|
||||
swaRope = nestedSwa
|
||||
}
|
||||
swaRope = swaRope.withDefaultPartialRotaryFactor(cmp.Or(fullRope.PartialRotaryFactor, float32(1)))
|
||||
|
||||
*p = lagunaModel{
|
||||
ModelParameters: raw.ModelParameters,
|
||||
NumHiddenLayers: raw.NumHiddenLayers,
|
||||
HiddenSize: raw.HiddenSize,
|
||||
IntermediateSize: raw.IntermediateSize,
|
||||
NumAttentionHeads: raw.NumAttentionHeads,
|
||||
NumKeyValueHeads: raw.NumKeyValueHeads,
|
||||
HeadDim: raw.HeadDim,
|
||||
RMSNormEPS: raw.RMSNormEPS,
|
||||
MaxPositionEmbeddings: raw.MaxPositionEmbeddings,
|
||||
SlidingWindow: raw.SlidingWindow,
|
||||
PartialRotaryFactor: cmp.Or(raw.PartialRotaryFactor, fullRope.PartialRotaryFactor),
|
||||
Gating: raw.Gating,
|
||||
QKNormType: cmp.Or(raw.QKNormType, "rmsnorm"),
|
||||
LayerTypes: raw.LayerTypes,
|
||||
NumAttentionHeadsPerLayer: raw.NumAttentionHeadsPerLayer,
|
||||
NumExperts: raw.NumExperts,
|
||||
NumExpertsPerTok: raw.NumExpertsPerTok,
|
||||
MoEIntermediateSize: raw.MoEIntermediateSize,
|
||||
SharedExpertIntermediateSize: raw.SharedExpertIntermediateSize,
|
||||
NormTopKProb: defaultBool(raw.NormTopKProb, true),
|
||||
MoeRoutedScalingFactor: raw.MoeRoutedScalingFactor,
|
||||
MoERouterUseSigmoid: defaultBool(raw.MoERouterUseSigmoid, true),
|
||||
MoEApplyRouterWeightOnInput: raw.MoEApplyRouterWeightOnInput,
|
||||
DecoderSparseStep: raw.DecoderSparseStep,
|
||||
MLPOnlyLayers: mlpOnlyLayers,
|
||||
MLPLayerTypes: raw.MLPLayerTypes,
|
||||
RopeParameters: fullRope,
|
||||
SwaRopeParameters: swaRope,
|
||||
SwaAttentionSinkEnabled: raw.SwaAttentionSinkEnabled,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultBool(v *bool, fallback bool) bool {
|
||||
if v == nil {
|
||||
return fallback
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
const (
|
||||
lagunaGatingFuncSoftmax uint32 = 1
|
||||
lagunaGatingFuncSigmoid uint32 = 2
|
||||
|
||||
lagunaLayerTypeGlobal uint32 = 0
|
||||
lagunaLayerTypeSliding uint32 = 1
|
||||
)
|
||||
|
||||
func (p *lagunaModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "laguna"
|
||||
// Laguna's chat template and built-in renderer both emit the leading
|
||||
// special token explicitly. Auto-prepending BOS here would duplicate it.
|
||||
kv["tokenizer.ggml.add_bos_token"] = false
|
||||
kv["tokenizer.ggml.pre"] = "laguna"
|
||||
// Laguna does not need tokenizer.chat_template at runtime: Ollama create
|
||||
// sets the Laguna renderer/parser from the architecture, and the renderer
|
||||
// owns prompt formatting.
|
||||
delete(kv, "tokenizer.chat_template")
|
||||
|
||||
kv["laguna.block_count"] = p.NumHiddenLayers
|
||||
kv["laguna.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["laguna.embedding_length"] = p.HiddenSize
|
||||
kv["laguna.feed_forward_length"] = p.IntermediateSize
|
||||
|
||||
if len(p.NumAttentionHeadsPerLayer) == int(p.NumHiddenLayers) {
|
||||
kv["laguna.attention.head_count"] = p.NumAttentionHeadsPerLayer
|
||||
} else {
|
||||
kv["laguna.attention.head_count"] = p.NumAttentionHeads
|
||||
}
|
||||
kv["laguna.attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
kv["laguna.attention.key_length"] = p.HeadDim
|
||||
kv["laguna.attention.value_length"] = p.HeadDim
|
||||
kv["laguna.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["laguna.attention.sliding_window"] = p.SlidingWindow
|
||||
kv["laguna.attention.sink_enabled"] = p.SwaAttentionSinkEnabled
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
encoded := make([]uint32, len(p.LayerTypes))
|
||||
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||
for i, layerType := range p.LayerTypes {
|
||||
if lagunaLayerIsSliding(layerType) {
|
||||
encoded[i] = lagunaLayerTypeSliding
|
||||
slidingPattern[i] = true
|
||||
} else {
|
||||
encoded[i] = lagunaLayerTypeGlobal
|
||||
}
|
||||
}
|
||||
kv["laguna.attention.layer_types"] = encoded
|
||||
kv["laguna.attention.sliding_window_pattern"] = slidingPattern
|
||||
}
|
||||
|
||||
if p.Gating.perHead() {
|
||||
kv["laguna.attention.gating_type"] = uint32(1)
|
||||
} else {
|
||||
kv["laguna.attention.gating_type"] = uint32(0)
|
||||
}
|
||||
kv["laguna.attention.qk_norm"] = p.QKNormType == "rmsnorm"
|
||||
|
||||
kv["laguna.expert_count"] = p.NumExperts
|
||||
kv["laguna.expert_used_count"] = p.NumExpertsPerTok
|
||||
kv["laguna.expert_feed_forward_length"] = p.MoEIntermediateSize
|
||||
kv["laguna.expert_shared_feed_forward_length"] = p.SharedExpertIntermediateSize
|
||||
kv["laguna.expert_shared_count"] = uint32(1)
|
||||
kv["laguna.expert_weights_norm"] = p.NormTopKProb
|
||||
kv["laguna.expert_weights_scale"] = p.MoeRoutedScalingFactor
|
||||
kv["laguna.expert_gating_func"] = lagunaMoeGatingFunc(p.MoERouterUseSigmoid)
|
||||
kv["laguna.decoder_sparse_step"] = cmp.Or(p.DecoderSparseStep, uint32(1))
|
||||
|
||||
if leading, ok := lagunaLeadingDensePrefix(p.MLPOnlyLayers); ok {
|
||||
kv["laguna.leading_dense_block_count"] = leading
|
||||
}
|
||||
if len(p.MLPOnlyLayers) > 0 {
|
||||
kv["laguna.dense_layers"] = p.MLPOnlyLayers
|
||||
}
|
||||
|
||||
ropeType := p.RopeParameters.ropeType()
|
||||
kv["laguna.rope.freq_base"] = cmp.Or(p.RopeParameters.RopeTheta, float32(10000))
|
||||
kv["laguna.rope.scaling.type"] = ropeType
|
||||
ropeFactor := cmp.Or(p.RopeParameters.Factor, float32(1))
|
||||
kv["laguna.rope.scaling.factor"] = ropeFactor
|
||||
kv["laguna.rope.scaling.original_context_length"] = p.RopeParameters.OriginalMaxPositionEmbeddings
|
||||
kv["laguna.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
|
||||
kv["laguna.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
|
||||
kv["laguna.rope.scaling.attn_factor"] = lagunaAttentionFactor(ropeType, ropeFactor, p.RopeParameters.AttentionFactor)
|
||||
kv["laguna.rope.partial_rotary_factor"] = cmp.Or(p.PartialRotaryFactor, float32(1))
|
||||
|
||||
swaRopeType := p.SwaRopeParameters.ropeType()
|
||||
kv["laguna.rope.swa.freq_base"] = cmp.Or(p.SwaRopeParameters.RopeTheta, float32(10000))
|
||||
kv["laguna.rope.swa.scaling.type"] = cmp.Or(swaRopeType, "linear")
|
||||
kv["laguna.rope.swa.scaling.factor"] = cmp.Or(p.SwaRopeParameters.Factor, float32(1))
|
||||
kv["laguna.rope.swa.partial_rotary_factor"] = cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1))
|
||||
|
||||
headDim := p.HeadDim
|
||||
if headDim == 0 && p.NumAttentionHeads > 0 {
|
||||
headDim = p.HiddenSize / p.NumAttentionHeads
|
||||
}
|
||||
kv["laguna.rope.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.PartialRotaryFactor, float32(1)))
|
||||
kv["laguna.rope.swa.dimension_count"] = lagunaRopeDim(headDim, cmp.Or(p.SwaRopeParameters.PartialRotaryFactor, float32(1)))
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *lagunaModel) parseMore(_ iofs.FS) error {
|
||||
return p.validate()
|
||||
}
|
||||
|
||||
func (p *lagunaModel) validate() error {
|
||||
if p.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("laguna: num_hidden_layers must be set")
|
||||
}
|
||||
if p.HiddenSize == 0 {
|
||||
return fmt.Errorf("laguna: hidden_size must be set")
|
||||
}
|
||||
if p.HeadDim == 0 {
|
||||
return fmt.Errorf("laguna: head_dim must be set")
|
||||
}
|
||||
if p.NumKeyValueHeads == 0 {
|
||||
return fmt.Errorf("laguna: num_key_value_heads must be set")
|
||||
}
|
||||
if p.SwaAttentionSinkEnabled {
|
||||
return fmt.Errorf("laguna: unsupported swa_attention_sink_enabled=true")
|
||||
}
|
||||
if !p.Gating.perHead() {
|
||||
return fmt.Errorf("laguna: unsupported attention gating %q: only gating=\"per-head\" is supported", p.Gating)
|
||||
}
|
||||
if p.QKNormType != "rmsnorm" {
|
||||
return fmt.Errorf("laguna: unsupported qk_norm_type %q: only rmsnorm is supported", p.QKNormType)
|
||||
}
|
||||
if !p.MoERouterUseSigmoid {
|
||||
return fmt.Errorf("laguna: unsupported moe_router_use_sigmoid=false")
|
||||
}
|
||||
if p.MoEApplyRouterWeightOnInput {
|
||||
return fmt.Errorf("laguna: unsupported moe_apply_router_weight_on_input=true")
|
||||
}
|
||||
if p.DecoderSparseStep != 0 && p.DecoderSparseStep != 1 {
|
||||
return fmt.Errorf("laguna: unsupported decoder_sparse_step=%d: only 1 is supported", p.DecoderSparseStep)
|
||||
}
|
||||
if len(p.MLPOnlyLayers) != 1 || p.MLPOnlyLayers[0] != 0 {
|
||||
return fmt.Errorf("laguna: unsupported mlp_only_layers=%v: only [0] is supported", p.MLPOnlyLayers)
|
||||
}
|
||||
if p.NumExperts == 0 {
|
||||
return fmt.Errorf("laguna: num_experts must be set")
|
||||
}
|
||||
if p.NumExpertsPerTok == 0 {
|
||||
return fmt.Errorf("laguna: num_experts_per_tok must be set")
|
||||
}
|
||||
if p.MoEIntermediateSize == 0 {
|
||||
return fmt.Errorf("laguna: moe_intermediate_size must be set")
|
||||
}
|
||||
if p.SharedExpertIntermediateSize == 0 {
|
||||
return fmt.Errorf("laguna: shared_expert_intermediate_size must be set")
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 && len(p.LayerTypes) != int(p.NumHiddenLayers) {
|
||||
return fmt.Errorf("laguna: layer_types has %d entries, expected %d", len(p.LayerTypes), p.NumHiddenLayers)
|
||||
}
|
||||
for i, layerType := range p.LayerTypes {
|
||||
if !lagunaLayerIsGlobal(layerType) && !lagunaLayerIsSliding(layerType) {
|
||||
return fmt.Errorf("laguna: unsupported layer_types[%d]=%q", i, layerType)
|
||||
}
|
||||
}
|
||||
if len(p.NumAttentionHeadsPerLayer) > 0 && len(p.NumAttentionHeadsPerLayer) != int(p.NumHiddenLayers) {
|
||||
return fmt.Errorf("laguna: num_attention_heads_per_layer has %d entries, expected %d", len(p.NumAttentionHeadsPerLayer), p.NumHiddenLayers)
|
||||
}
|
||||
if len(p.NumAttentionHeadsPerLayer) == 0 && p.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("laguna: num_attention_heads or num_attention_heads_per_layer must be set")
|
||||
}
|
||||
for i, heads := range p.NumAttentionHeadsPerLayer {
|
||||
if heads == 0 {
|
||||
return fmt.Errorf("laguna: num_attention_heads_per_layer[%d] must be non-zero", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *lagunaModel) numHeadsForLayer(layer uint32) uint32 {
|
||||
if len(p.NumAttentionHeadsPerLayer) > int(layer) && p.NumAttentionHeadsPerLayer[layer] > 0 {
|
||||
return p.NumAttentionHeadsPerLayer[layer]
|
||||
}
|
||||
return p.NumAttentionHeads
|
||||
}
|
||||
|
||||
func (p *lagunaModel) layerUsesMoE(layer uint32) bool {
|
||||
for _, denseLayer := range p.MLPOnlyLayers {
|
||||
if denseLayer == layer {
|
||||
return false
|
||||
}
|
||||
}
|
||||
step := cmp.Or(p.DecoderSparseStep, uint32(1))
|
||||
return p.NumExperts > 0 && (layer+1)%step == 0
|
||||
}
|
||||
|
||||
func (p *lagunaModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.g_proj", "attn_g",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.experts.e_score_correction_bias", "exp_probs_b.bias",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.experts.*.gate_proj", "ffn_gate_exps",
|
||||
"mlp.experts.*.up_proj", "ffn_up_exps",
|
||||
"mlp.experts.*.down_proj", "ffn_down_exps",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *lagunaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// Current Laguna drops store routed MoE experts as separate per-expert
|
||||
// tensors. GGUF stores each projection as one stacked tensor. If future
|
||||
// drops change expert naming or layout, update these patterns with a
|
||||
// focused conversion test using the new tensor names.
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*3)
|
||||
for i := range p.NumHiddenLayers {
|
||||
merges = append(merges,
|
||||
merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
},
|
||||
merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
},
|
||||
merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
out, rest := mergeTensors(ts, merges...)
|
||||
for _, t := range rest {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *lagunaModel) specialTokenTypes() []string {
|
||||
return []string{"bos", "eos", "pad", "unk"}
|
||||
}
|
||||
|
||||
func lagunaLayerIsSliding(layerType string) bool {
|
||||
return strings.EqualFold(layerType, "sliding_attention")
|
||||
}
|
||||
|
||||
func lagunaLayerIsGlobal(layerType string) bool {
|
||||
return strings.EqualFold(layerType, "full_attention") || strings.EqualFold(layerType, "global_attention")
|
||||
}
|
||||
|
||||
func lagunaLeadingDensePrefix(layers []uint32) (uint32, bool) {
|
||||
for i, v := range layers {
|
||||
if v != uint32(i) {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return uint32(len(layers)), true
|
||||
}
|
||||
|
||||
func lagunaDenseLayers(mlpOnlyLayers []uint32, mlpLayerTypes []string) ([]uint32, error) {
|
||||
if len(mlpOnlyLayers) > 0 {
|
||||
return mlpOnlyLayers, nil
|
||||
}
|
||||
if len(mlpLayerTypes) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
denseLayers := make([]uint32, 0, len(mlpLayerTypes))
|
||||
for i, layerType := range mlpLayerTypes {
|
||||
switch {
|
||||
case strings.EqualFold(layerType, "dense"):
|
||||
denseLayers = append(denseLayers, uint32(i))
|
||||
case strings.EqualFold(layerType, "sparse"):
|
||||
default:
|
||||
return nil, fmt.Errorf("laguna: unsupported mlp_layer_types[%d]=%q", i, layerType)
|
||||
}
|
||||
}
|
||||
return denseLayers, nil
|
||||
}
|
||||
|
||||
func lagunaMoeGatingFunc(useSigmoid bool) uint32 {
|
||||
if useSigmoid {
|
||||
return lagunaGatingFuncSigmoid
|
||||
}
|
||||
return lagunaGatingFuncSoftmax
|
||||
}
|
||||
|
||||
func lagunaAttentionFactor(ropeType string, scaleFactor, attentionFactor float32) float32 {
|
||||
if attentionFactor != 0 {
|
||||
return attentionFactor
|
||||
}
|
||||
if strings.EqualFold(ropeType, "yarn") && scaleFactor > 1 {
|
||||
return float32(0.1*math.Log(float64(scaleFactor)) + 1)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func lagunaRopeDim(headDim uint32, partialRotaryFactor float32) uint32 {
|
||||
if headDim == 0 {
|
||||
return 0
|
||||
}
|
||||
dim := uint32(float32(headDim) * partialRotaryFactor)
|
||||
if dim == 0 || dim > headDim {
|
||||
dim = headDim
|
||||
}
|
||||
if dim%2 != 0 {
|
||||
dim--
|
||||
}
|
||||
if dim == 0 {
|
||||
return headDim
|
||||
}
|
||||
return dim
|
||||
}
|
||||
|
||||
var (
|
||||
_ ModelConverter = (*lagunaModel)(nil)
|
||||
_ moreParser = (*lagunaModel)(nil)
|
||||
)
|
||||
450
convert/convert_laguna_test.go
Normal file
450
convert/convert_laguna_test.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type lagunaTestTensor struct {
|
||||
tensorBase
|
||||
}
|
||||
|
||||
func newLagunaTestTensor(name string, shape ...uint64) Tensor {
|
||||
return &lagunaTestTensor{tensorBase: tensorBase{name: name, shape: shape}}
|
||||
}
|
||||
|
||||
func (t *lagunaTestTensor) WriteTo(io.Writer) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (t *lagunaTestTensor) Clone() Tensor {
|
||||
return &lagunaTestTensor{tensorBase: tensorBase{
|
||||
name: t.name,
|
||||
shape: append([]uint64(nil), t.shape...),
|
||||
}}
|
||||
}
|
||||
|
||||
func TestLagunaReplacements(t *testing.T) {
|
||||
p := lagunaModel{}
|
||||
r := strings.NewReplacer(p.Replacements()...)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"embed", "model.embed_tokens.weight", "token_embd.weight"},
|
||||
{"final_norm", "model.norm.weight", "output_norm.weight"},
|
||||
{"lm_head", "lm_head.weight", "output.weight"},
|
||||
{"block prefix", "model.layers.7.input_layernorm.weight", "blk.7.attn_norm.weight"},
|
||||
{"q", "model.layers.3.self_attn.q_proj.weight", "blk.3.attn_q.weight"},
|
||||
{"k", "model.layers.3.self_attn.k_proj.weight", "blk.3.attn_k.weight"},
|
||||
{"v", "model.layers.3.self_attn.v_proj.weight", "blk.3.attn_v.weight"},
|
||||
{"o", "model.layers.3.self_attn.o_proj.weight", "blk.3.attn_output.weight"},
|
||||
{"g", "model.layers.3.self_attn.g_proj.weight", "blk.3.attn_g.weight"},
|
||||
{"q_norm", "model.layers.3.self_attn.q_norm.weight", "blk.3.attn_q_norm.weight"},
|
||||
{"k_norm", "model.layers.3.self_attn.k_norm.weight", "blk.3.attn_k_norm.weight"},
|
||||
{"post_attn_norm", "model.layers.3.post_attention_layernorm.weight", "blk.3.ffn_norm.weight"},
|
||||
{"dense gate", "model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"},
|
||||
{"dense up", "model.layers.0.mlp.up_proj.weight", "blk.0.ffn_up.weight"},
|
||||
{"dense down", "model.layers.0.mlp.down_proj.weight", "blk.0.ffn_down.weight"},
|
||||
{"shexp gate", "model.layers.5.mlp.shared_expert.gate_proj.weight", "blk.5.ffn_gate_shexp.weight"},
|
||||
{"shexp up", "model.layers.5.mlp.shared_expert.up_proj.weight", "blk.5.ffn_up_shexp.weight"},
|
||||
{"shexp down", "model.layers.5.mlp.shared_expert.down_proj.weight", "blk.5.ffn_down_shexp.weight"},
|
||||
{"router", "model.layers.5.mlp.gate.weight", "blk.5.ffn_gate_inp.weight"},
|
||||
{"score bias", "model.layers.5.mlp.experts.e_score_correction_bias", "blk.5.exp_probs_b.bias"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := r.Replace(tc.in); got != tc.want {
|
||||
t.Errorf("Replace(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaValidateRejectsUnsupportedVariants(t *testing.T) {
|
||||
base := validLagunaTestModel()
|
||||
tests := []struct {
|
||||
name string
|
||||
edit func(*lagunaModel)
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "per-element gating",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.Gating = "per-element"
|
||||
},
|
||||
want: "unsupported attention gating",
|
||||
},
|
||||
{
|
||||
name: "attention sinks",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.SwaAttentionSinkEnabled = true
|
||||
},
|
||||
want: "swa_attention_sink_enabled=true",
|
||||
},
|
||||
{
|
||||
name: "qk norm disabled",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.QKNormType = "none"
|
||||
},
|
||||
want: "unsupported qk_norm_type",
|
||||
},
|
||||
{
|
||||
name: "softmax moe",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.MoERouterUseSigmoid = false
|
||||
},
|
||||
want: "moe_router_use_sigmoid=false",
|
||||
},
|
||||
{
|
||||
name: "router weight on input",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.MoEApplyRouterWeightOnInput = true
|
||||
},
|
||||
want: "moe_apply_router_weight_on_input=true",
|
||||
},
|
||||
{
|
||||
name: "unknown layer type",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.LayerTypes[1] = "local_attention"
|
||||
},
|
||||
want: "unsupported layer_types[1]",
|
||||
},
|
||||
{
|
||||
name: "nonstandard dense layout",
|
||||
edit: func(m *lagunaModel) {
|
||||
m.MLPOnlyLayers = []uint32{0, 3}
|
||||
},
|
||||
want: "unsupported mlp_only_layers",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m := base
|
||||
m.LayerTypes = append([]string(nil), base.LayerTypes...)
|
||||
m.NumAttentionHeadsPerLayer = append([]uint32(nil), base.NumAttentionHeadsPerLayer...)
|
||||
m.MLPOnlyLayers = append([]uint32(nil), base.MLPOnlyLayers...)
|
||||
tc.edit(&m)
|
||||
err := m.validate()
|
||||
if err == nil || !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("validate() error = %v, want substring %q", err, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaGAConfigNormalizesBoolGatingAndNestedRope(t *testing.T) {
|
||||
var m lagunaModel
|
||||
if err := json.Unmarshal([]byte(`{
|
||||
"architectures": ["LagunaForCausalLM"],
|
||||
"num_hidden_layers": 1,
|
||||
"hidden_size": 8,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 4,
|
||||
"gating": true,
|
||||
"num_experts": 2,
|
||||
"num_experts_per_tok": 1,
|
||||
"moe_intermediate_size": 4,
|
||||
"shared_expert_intermediate_size": 4,
|
||||
"decoder_sparse_step": 1,
|
||||
"mlp_layer_types": ["dense"],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"rope_theta": 500000,
|
||||
"rope_type": "yarn",
|
||||
"factor": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"beta_fast": 64,
|
||||
"beta_slow": 1,
|
||||
"attention_factor": 1,
|
||||
"partial_rotary_factor": 0.5
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000,
|
||||
"rope_type": "default",
|
||||
"partial_rotary_factor": 1
|
||||
}
|
||||
}
|
||||
}`), &m); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if err := m.validate(); err != nil {
|
||||
t.Fatalf("validate() error = %v", err)
|
||||
}
|
||||
if m.Gating != "true" {
|
||||
t.Fatalf("Gating = %q, want raw true marker", m.Gating)
|
||||
}
|
||||
if !m.Gating.perHead() {
|
||||
t.Fatal("expected bool gating to normalize as per-head support")
|
||||
}
|
||||
if m.QKNormType != "rmsnorm" {
|
||||
t.Fatalf("QKNormType = %q, want rmsnorm default", m.QKNormType)
|
||||
}
|
||||
if !m.MoERouterUseSigmoid {
|
||||
t.Fatal("MoERouterUseSigmoid should default true")
|
||||
}
|
||||
if !m.NormTopKProb {
|
||||
t.Fatal("NormTopKProb should default true")
|
||||
}
|
||||
if diff := cmp.Diff(m.MLPOnlyLayers, []uint32{0}); diff != "" {
|
||||
t.Fatalf("MLPOnlyLayers mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if m.RopeParameters.RopeTheta != 500000 || m.RopeParameters.PartialRotaryFactor != 0.5 {
|
||||
t.Fatalf("full rope = %#v, want theta=500000 partial=0.5", m.RopeParameters)
|
||||
}
|
||||
if m.SwaRopeParameters.RopeTheta != 10000 || m.SwaRopeParameters.PartialRotaryFactor != 1 {
|
||||
t.Fatalf("swa rope = %#v, want theta=10000 partial=1", m.SwaRopeParameters)
|
||||
}
|
||||
}
|
||||
|
||||
func validLagunaTestModel() lagunaModel {
|
||||
return lagunaModel{
|
||||
ModelParameters: ModelParameters{
|
||||
VocabSize: 32,
|
||||
},
|
||||
NumHiddenLayers: 2,
|
||||
HiddenSize: 8,
|
||||
IntermediateSize: 16,
|
||||
NumAttentionHeads: 2,
|
||||
NumKeyValueHeads: 1,
|
||||
HeadDim: 4,
|
||||
RMSNormEPS: 1e-6,
|
||||
MaxPositionEmbeddings: 4096,
|
||||
SlidingWindow: 512,
|
||||
Gating: "per-head",
|
||||
QKNormType: "rmsnorm",
|
||||
LayerTypes: []string{"global_attention", "sliding_attention"},
|
||||
NumAttentionHeadsPerLayer: []uint32{2, 2},
|
||||
NumExperts: 2,
|
||||
NumExpertsPerTok: 1,
|
||||
MoEIntermediateSize: 4,
|
||||
SharedExpertIntermediateSize: 4,
|
||||
NormTopKProb: true,
|
||||
MoeRoutedScalingFactor: 2.5,
|
||||
MoERouterUseSigmoid: true,
|
||||
DecoderSparseStep: 1,
|
||||
MLPOnlyLayers: []uint32{0},
|
||||
}
|
||||
}
|
||||
|
||||
func validLagunaTestTensors(m lagunaModel) []Tensor {
|
||||
ts := []Tensor{
|
||||
newLagunaTestTensor("token_embd.weight", uint64(m.VocabSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor("output_norm.weight", uint64(m.HiddenSize)),
|
||||
}
|
||||
|
||||
for layer := range m.NumHiddenLayers {
|
||||
prefix := fmt.Sprintf("blk.%d", layer)
|
||||
heads := uint64(m.numHeadsForLayer(layer))
|
||||
attnWidth := heads * uint64(m.HeadDim)
|
||||
kvWidth := uint64(m.NumKeyValueHeads * m.HeadDim)
|
||||
ts = append(ts,
|
||||
newLagunaTestTensor(prefix+".attn_norm.weight", uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".ffn_norm.weight", uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".attn_q.weight", attnWidth, uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".attn_k.weight", kvWidth, uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".attn_v.weight", kvWidth, uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".attn_output.weight", uint64(m.HiddenSize), attnWidth),
|
||||
newLagunaTestTensor(prefix+".attn_g.weight", heads, uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".attn_q_norm.weight", uint64(m.HeadDim)),
|
||||
newLagunaTestTensor(prefix+".attn_k_norm.weight", uint64(m.HeadDim)),
|
||||
)
|
||||
|
||||
if m.layerUsesMoE(layer) {
|
||||
ts = append(ts,
|
||||
newLagunaTestTensor(prefix+".ffn_gate_inp.weight", uint64(m.NumExperts), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".exp_probs_b.bias", uint64(m.NumExperts)),
|
||||
newLagunaTestTensor(prefix+".ffn_gate_shexp.weight", uint64(m.SharedExpertIntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".ffn_up_shexp.weight", uint64(m.SharedExpertIntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".ffn_down_shexp.weight", uint64(m.HiddenSize), uint64(m.SharedExpertIntermediateSize)),
|
||||
)
|
||||
for expert := range m.NumExperts {
|
||||
ts = append(ts,
|
||||
newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.gate_proj.weight", prefix, expert), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.up_proj.weight", prefix, expert), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(fmt.Sprintf("%s.mlp.experts.%d.down_proj.weight", prefix, expert), uint64(m.HiddenSize), uint64(m.MoEIntermediateSize)),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
ts = append(ts,
|
||||
newLagunaTestTensor(prefix+".ffn_gate.weight", uint64(m.IntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".ffn_up.weight", uint64(m.IntermediateSize), uint64(m.HiddenSize)),
|
||||
newLagunaTestTensor(prefix+".ffn_down.weight", uint64(m.HiddenSize), uint64(m.IntermediateSize)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
func TestLagunaTensorsMergeRoutedExperts(t *testing.T) {
|
||||
m := validLagunaTestModel()
|
||||
out := m.Tensors(validLagunaTestTensors(m))
|
||||
|
||||
tensors := make(map[string]*ggml.Tensor, len(out))
|
||||
for _, t := range out {
|
||||
tensors[t.Name] = t
|
||||
}
|
||||
|
||||
tests := map[string][]uint64{
|
||||
"blk.1.ffn_gate_exps.weight": {uint64(m.NumExperts), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)},
|
||||
"blk.1.ffn_up_exps.weight": {uint64(m.NumExperts), uint64(m.MoEIntermediateSize), uint64(m.HiddenSize)},
|
||||
"blk.1.ffn_down_exps.weight": {uint64(m.NumExperts), uint64(m.HiddenSize), uint64(m.MoEIntermediateSize)},
|
||||
}
|
||||
for name, wantShape := range tests {
|
||||
tensor, ok := tensors[name]
|
||||
if !ok {
|
||||
t.Fatalf("missing merged tensor %q", name)
|
||||
}
|
||||
if diff := cmp.Diff(wantShape, tensor.Shape); diff != "" {
|
||||
t.Fatalf("%s shape mismatch (-want +got):\n%s", name, diff)
|
||||
}
|
||||
}
|
||||
|
||||
for expert := range m.NumExperts {
|
||||
name := fmt.Sprintf("blk.1.mlp.experts.%d.gate_proj.weight", expert)
|
||||
if _, ok := tensors[name]; ok {
|
||||
t.Fatalf("unexpected unmerged expert tensor %q", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaKVShape(t *testing.T) {
|
||||
m := lagunaModel{
|
||||
NumHiddenLayers: 4,
|
||||
HiddenSize: 128,
|
||||
IntermediateSize: 256,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
HeadDim: 16,
|
||||
RMSNormEPS: 1e-6,
|
||||
MaxPositionEmbeddings: 4096,
|
||||
SlidingWindow: 512,
|
||||
PartialRotaryFactor: 0.5,
|
||||
Gating: "per-head",
|
||||
QKNormType: "rmsnorm",
|
||||
LayerTypes: []string{"full_attention", "sliding_attention", "sliding_attention", "sliding_attention"},
|
||||
NumAttentionHeadsPerLayer: []uint32{8, 16, 16, 16},
|
||||
NumExperts: 32,
|
||||
NumExpertsPerTok: 4,
|
||||
MoEIntermediateSize: 64,
|
||||
SharedExpertIntermediateSize: 64,
|
||||
NormTopKProb: true,
|
||||
MoeRoutedScalingFactor: 2.5,
|
||||
MoERouterUseSigmoid: true,
|
||||
DecoderSparseStep: 1,
|
||||
MLPOnlyLayers: []uint32{0},
|
||||
}
|
||||
m.RopeParameters.RopeTheta = 500000
|
||||
m.RopeParameters.RopeType = "yarn"
|
||||
m.RopeParameters.Factor = 32
|
||||
m.RopeParameters.OriginalMaxPositionEmbeddings = 4096
|
||||
m.RopeParameters.BetaFast = 64
|
||||
m.RopeParameters.BetaSlow = 1
|
||||
m.SwaRopeParameters.RopeTheta = 10000
|
||||
m.SwaRopeParameters.RopeType = "linear"
|
||||
m.SwaRopeParameters.Factor = 1
|
||||
m.SwaRopeParameters.PartialRotaryFactor = 1
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}, Template: "{% include 'chat_template.jinja' %}"})
|
||||
required := []string{
|
||||
"general.architecture",
|
||||
"tokenizer.ggml.pre",
|
||||
"laguna.block_count",
|
||||
"laguna.context_length",
|
||||
"laguna.embedding_length",
|
||||
"laguna.feed_forward_length",
|
||||
"laguna.attention.head_count",
|
||||
"laguna.attention.head_count_kv",
|
||||
"laguna.attention.key_length",
|
||||
"laguna.attention.value_length",
|
||||
"laguna.attention.layer_norm_rms_epsilon",
|
||||
"laguna.attention.sliding_window",
|
||||
"laguna.attention.layer_types",
|
||||
"laguna.attention.sliding_window_pattern",
|
||||
"laguna.attention.gating_type",
|
||||
"laguna.attention.qk_norm",
|
||||
"laguna.expert_count",
|
||||
"laguna.expert_used_count",
|
||||
"laguna.expert_feed_forward_length",
|
||||
"laguna.expert_shared_feed_forward_length",
|
||||
"laguna.expert_shared_count",
|
||||
"laguna.expert_weights_norm",
|
||||
"laguna.expert_weights_scale",
|
||||
"laguna.expert_gating_func",
|
||||
"laguna.leading_dense_block_count",
|
||||
"laguna.dense_layers",
|
||||
"laguna.rope.freq_base",
|
||||
"laguna.rope.scaling.type",
|
||||
"laguna.rope.scaling.factor",
|
||||
"laguna.rope.partial_rotary_factor",
|
||||
"laguna.rope.swa.freq_base",
|
||||
"laguna.rope.swa.scaling.type",
|
||||
"laguna.rope.dimension_count",
|
||||
"laguna.rope.swa.dimension_count",
|
||||
}
|
||||
for _, k := range required {
|
||||
if _, ok := kv[k]; !ok {
|
||||
t.Errorf("missing required KV: %s", k)
|
||||
}
|
||||
}
|
||||
|
||||
if got := kv["general.architecture"]; got != "laguna" {
|
||||
t.Errorf("architecture = %v, want laguna", got)
|
||||
}
|
||||
if got := kv["tokenizer.ggml.add_bos_token"]; got != false {
|
||||
t.Errorf("tokenizer.ggml.add_bos_token = %v, want false", got)
|
||||
}
|
||||
if _, ok := kv["tokenizer.chat_template"]; ok {
|
||||
t.Fatal("tokenizer.chat_template should be omitted for Laguna")
|
||||
}
|
||||
if got := kv["laguna.expert_gating_func"]; got != lagunaGatingFuncSigmoid {
|
||||
t.Errorf("expert_gating_func = %v, want sigmoid(%d)", got, lagunaGatingFuncSigmoid)
|
||||
}
|
||||
if got := kv["laguna.leading_dense_block_count"]; got != uint32(1) {
|
||||
t.Errorf("leading_dense_block_count = %v, want 1", got)
|
||||
}
|
||||
if got := kv["laguna.rope.dimension_count"]; got != uint32(8) {
|
||||
t.Errorf("rope.dimension_count = %v, want 8", got)
|
||||
}
|
||||
if got := kv["laguna.rope.swa.dimension_count"]; got != uint32(16) {
|
||||
t.Errorf("rope.swa.dimension_count = %v, want 16", got)
|
||||
}
|
||||
if got, ok := kv["laguna.attention.layer_types"].([]uint32); !ok || len(got) != 4 || got[0] != 0 || got[1] != 1 || got[2] != 1 || got[3] != 1 {
|
||||
t.Fatalf("layer_types = %#v, want [0 1 1 1]", kv["laguna.attention.layer_types"])
|
||||
}
|
||||
if got, ok := kv["laguna.attention.sliding_window_pattern"].([]bool); !ok || len(got) != 4 || got[0] || !got[1] || !got[2] || !got[3] {
|
||||
t.Fatalf("sliding_window_pattern = %#v, want [false true true true]", kv["laguna.attention.sliding_window_pattern"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaKVYarnAttentionFactorFallback(t *testing.T) {
|
||||
m := validLagunaTestModel()
|
||||
m.RopeParameters.RopeType = "yarn"
|
||||
m.RopeParameters.Factor = 32
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
got, ok := kv["laguna.rope.scaling.attn_factor"].(float32)
|
||||
if !ok {
|
||||
t.Fatalf("attn_factor type = %T, want float32", kv["laguna.rope.scaling.attn_factor"])
|
||||
}
|
||||
|
||||
want := float32(0.1*math.Log(32) + 1)
|
||||
if diff := math.Abs(float64(got - want)); diff > 1e-6 {
|
||||
t.Fatalf("attn_factor = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
@@ -69,7 +70,415 @@ type nemotronHModel struct {
|
||||
ExpertGroupUsedCount uint32 `json:"topk_group"`
|
||||
}
|
||||
|
||||
type nemotronHNanoVLModel struct {
|
||||
ModelParameters
|
||||
MaxSequenceLength uint32 `json:"max_sequence_length"`
|
||||
ForceImageSize uint32 `json:"force_image_size"`
|
||||
DownsampleRatio float32 `json:"downsample_ratio"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
UseThumbnail *bool `json:"use_thumbnail"`
|
||||
ImgContextTokenID uint32 `json:"img_context_token_id"`
|
||||
ImgContextToken string `json:"img_context_token"`
|
||||
ImgStartToken string `json:"img_start_token"`
|
||||
ImgEndToken string `json:"img_end_token"`
|
||||
VitHiddenSize uint32 `json:"vit_hidden_size"`
|
||||
ProjectorHidden uint32 `json:"projector_hidden_size"`
|
||||
SoundContextTokenID uint32 `json:"sound_context_token_id"`
|
||||
SoundContextToken string `json:"sound_context_token"`
|
||||
NormMean []float32 `json:"norm_mean"`
|
||||
NormStd []float32 `json:"norm_std"`
|
||||
VisionConfig radioConfig `json:"vision_config"`
|
||||
SoundConfig soundConfig `json:"sound_config"`
|
||||
LLMConfig nemotronHModel `json:"llm_config"`
|
||||
Preprocessor struct {
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
DownsampleRatio float32 `json:"downsample_ratio"`
|
||||
MaxNumTiles uint32 `json:"max_num_tiles"`
|
||||
UseThumbnail *bool `json:"use_thumbnail"`
|
||||
NormMean []float32 `json:"norm_mean"`
|
||||
NormStd []float32 `json:"norm_std"`
|
||||
}
|
||||
}
|
||||
|
||||
type soundConfig struct {
|
||||
ModelType string `json:"model_type"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
ConvKernelSize uint32 `json:"conv_kernel_size"`
|
||||
SubsamplingConvChannels uint32 `json:"subsampling_conv_channels"`
|
||||
SubsamplingConvKernelSize uint32 `json:"subsampling_conv_kernel_size"`
|
||||
SubsamplingConvStride uint32 `json:"subsampling_conv_stride"`
|
||||
SubsamplingFactor uint32 `json:"subsampling_factor"`
|
||||
NumMelBins uint32 `json:"num_mel_bins"`
|
||||
ProjectionHiddenSize uint32 `json:"projection_hidden_size"`
|
||||
SamplingRate uint32 `json:"sampling_rate"`
|
||||
ScaleInput bool `json:"scale_input"`
|
||||
}
|
||||
|
||||
type radioConfig struct {
|
||||
Version string `json:"version"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
MaxResolution uint32 `json:"max_resolution"`
|
||||
MinNumPatches uint32 `json:"min_num_patches"`
|
||||
MaxNumPatches uint32 `json:"max_num_patches"`
|
||||
SeparateVideoEmbedder bool `json:"separate_video_embedder"`
|
||||
Args struct {
|
||||
MinNumPatches uint32 `json:"min_num_patches"`
|
||||
MaxNumPatches uint32 `json:"max_num_patches"`
|
||||
} `json:"args"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*nemotronHModel)(nil)
|
||||
var _ ModelConverter = (*nemotronHNanoVLModel)(nil)
|
||||
|
||||
func (n *nemotronHNanoVLModel) parseMore(fsys fs.FS) error {
|
||||
if n.MaxSequenceLength > 0 {
|
||||
n.LLMConfig.MaxPositionEmbeddings = n.MaxSequenceLength
|
||||
}
|
||||
|
||||
if err := n.LLMConfig.parseMore(fsys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
|
||||
if err := json.Unmarshal(bts, &n.Preprocessor); err != nil {
|
||||
return fmt.Errorf("nemotron_h_omni: parse preprocessor_config.json: %w", err)
|
||||
}
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
if version := strings.TrimSpace(n.VisionConfig.Version); version != "" && version != "radio_v2.5-h" {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported RADIO version %q", version)
|
||||
}
|
||||
if patchSize := n.visionPatchSize(); patchSize != 16 {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported vision patch_size=%d", patchSize)
|
||||
}
|
||||
if scale := n.visionProjectorScaleFactor(); scale != 2 {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported vision projector scale factor=%d", scale)
|
||||
}
|
||||
|
||||
if n.SoundConfig.NumHiddenLayers > 0 {
|
||||
if modelType := strings.TrimSpace(n.SoundConfig.ModelType); modelType != "" && modelType != "parakeet" {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported sound model_type %q", modelType)
|
||||
}
|
||||
if n.soundHiddenSize() == 0 {
|
||||
return fmt.Errorf("nemotron_h_omni: sound hidden_size must be set")
|
||||
}
|
||||
if n.soundAttentionHeads() == 0 {
|
||||
return fmt.Errorf("nemotron_h_omni: sound num_attention_heads must be set")
|
||||
}
|
||||
if n.soundSubsamplingFactor() != 8 {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported sound subsampling_factor=%d", n.soundSubsamplingFactor())
|
||||
}
|
||||
if n.soundMelBins() != 128 {
|
||||
return fmt.Errorf("nemotron_h_omni: unsupported sound num_mel_bins=%d", n.soundMelBins())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) KV(t *Tokenizer) KV {
|
||||
kv := n.LLMConfig.KV(t)
|
||||
kv["general.architecture"] = "nemotron_h_omni"
|
||||
|
||||
kv["vision.block_count"] = n.visionBlockCount()
|
||||
kv["vision.embedding_length"] = n.visionEmbeddingLength()
|
||||
kv["vision.feed_forward_length"] = n.visionFeedForwardLength()
|
||||
kv["vision.attention.head_count"] = n.visionAttentionHeads()
|
||||
kv["vision.attention.layer_norm_epsilon"] = float32(1e-6)
|
||||
kv["vision.patch_size"] = n.visionPatchSize()
|
||||
kv["vision.image_size"] = n.visionImageSize()
|
||||
kv["vision.max_tiles"] = n.visionMaxTiles()
|
||||
kv["vision.use_thumbnail"] = n.visionUseThumbnail()
|
||||
if minPatches := n.visionMinNumPatches(); minPatches > 0 {
|
||||
kv["vision.min_num_patches"] = minPatches
|
||||
}
|
||||
if maxPatches := n.visionMaxNumPatches(); maxPatches > 0 {
|
||||
kv["vision.max_num_patches"] = maxPatches
|
||||
}
|
||||
kv["vision.num_channels"] = uint32(3)
|
||||
kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(n.visionMean(), imageNetStandardMean))
|
||||
kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(n.visionStd(), imageNetStandardSTD))
|
||||
kv["vision.projector.scale_factor"] = n.visionProjectorScaleFactor()
|
||||
|
||||
setTokenID := func(key string, explicit uint32, token string) {
|
||||
if explicit > 0 {
|
||||
kv[key] = explicit
|
||||
return
|
||||
}
|
||||
if t == nil || t.Vocabulary == nil {
|
||||
return
|
||||
}
|
||||
for i, v := range t.Vocabulary.Tokens {
|
||||
if v == token {
|
||||
kv[key] = uint32(i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTokenID("vision.image_token_id", n.ImgContextTokenID, cmp.Or(n.ImgContextToken, "<image>"))
|
||||
setTokenID("vision.image_start_token_id", 0, cmp.Or(n.ImgStartToken, "<img>"))
|
||||
setTokenID("vision.image_end_token_id", 0, cmp.Or(n.ImgEndToken, "</img>"))
|
||||
|
||||
if n.SoundConfig.NumHiddenLayers > 0 {
|
||||
kv["audio.block_count"] = n.SoundConfig.NumHiddenLayers
|
||||
kv["audio.embedding_length"] = n.soundHiddenSize()
|
||||
kv["audio.feed_forward_length"] = n.soundFeedForwardLength()
|
||||
kv["audio.attention.head_count"] = n.soundAttentionHeads()
|
||||
kv["audio.attention.layer_norm_epsilon"] = float32(1e-5)
|
||||
kv["audio.conv_kernel_size"] = n.soundConvKernelSize()
|
||||
kv["audio.num_mel_bins"] = n.soundMelBins()
|
||||
kv["audio.sample_rate"] = n.soundSampleRate()
|
||||
kv["audio.subsampling_factor"] = n.soundSubsamplingFactor()
|
||||
kv["audio.subsampling_conv_channels"] = n.soundSubsamplingConvChannels()
|
||||
kv["audio.subsampling_conv_kernel_size"] = n.soundSubsamplingConvKernelSize()
|
||||
kv["audio.subsampling_conv_stride"] = n.soundSubsamplingConvStride()
|
||||
kv["audio.projection_hidden_size"] = n.soundProjectionHiddenSize()
|
||||
kv["audio.scale_input"] = n.SoundConfig.ScaleInput
|
||||
setTokenID("audio.sound_token_id", n.SoundContextTokenID, cmp.Or(n.SoundContextToken, "<so_embedding>"))
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var textTensors []Tensor
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case isNemotronHNanoVLOmittedTensor(t.Name()):
|
||||
continue
|
||||
case strings.Contains(t.Name(), ".attn_qkv"):
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||
))...)
|
||||
case t.Name() == "v.position_embd":
|
||||
shape := t.Shape()
|
||||
if len(shape) == 3 && shape[0] == 1 {
|
||||
shape = shape[1:]
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
case strings.HasPrefix(t.Name(), "a.") || strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm."):
|
||||
name := t.Name()
|
||||
shape := slices.Clone(t.Shape())
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, ".conv_dw.") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
|
||||
t.SetRepacker(squeezeMiddleDim)
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
if strings.HasPrefix(name, "a.blk.") && (strings.Contains(name, ".conv_pw1.") || strings.Contains(name, ".conv_pw2.")) && strings.HasSuffix(name, ".weight") && len(shape) == 3 && shape[2] == 1 {
|
||||
t.SetRepacker(squeezeLastDim)
|
||||
shape = shape[:2]
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
default:
|
||||
textTensors = append(textTensors, t)
|
||||
}
|
||||
}
|
||||
|
||||
return append(n.LLMConfig.Tensors(textTensors), out...)
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) Replacements() []string {
|
||||
return append([]string{
|
||||
"language_model.", "",
|
||||
"vision_model.radio_model.model.patch_generator.embedder", "v.patch_embd",
|
||||
"vision_model.radio_model.model.patch_generator.pos_embed", "v.position_embd",
|
||||
"vision_model.radio_model.model.patch_generator.cls_token.token", "v.cls_embd",
|
||||
"vision_model.radio_model.model.blocks", "v.blk",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"mlp.fc1", "ffn_up",
|
||||
"mlp.fc2", "ffn_down",
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
"mlp1.0", "mm.norm",
|
||||
"mlp1.1", "mm.1",
|
||||
"mlp1.3", "mm.2",
|
||||
"sound_encoder.encoder.feature_extractor.featurizer.fb", "a.feature_extractor.fb",
|
||||
"sound_encoder.encoder.feature_extractor.featurizer.window", "a.feature_extractor.window",
|
||||
"sound_encoder.encoder.subsampling.layers.0", "a.subsampling.conv0",
|
||||
"sound_encoder.encoder.subsampling.layers.2", "a.subsampling.dw1",
|
||||
"sound_encoder.encoder.subsampling.layers.3", "a.subsampling.pw1",
|
||||
"sound_encoder.encoder.subsampling.layers.5", "a.subsampling.dw2",
|
||||
"sound_encoder.encoder.subsampling.layers.6", "a.subsampling.pw2",
|
||||
"sound_encoder.encoder.subsampling.linear", "a.subsampling.linear",
|
||||
"sound_encoder.encoder.layers", "a.blk",
|
||||
"feed_forward1.linear1", "ffn1_up",
|
||||
"feed_forward1.linear2", "ffn1_down",
|
||||
"feed_forward2.linear1", "ffn2_up",
|
||||
"feed_forward2.linear2", "ffn2_down",
|
||||
"norm_feed_forward1", "ffn1_norm",
|
||||
"norm_feed_forward2", "ffn2_norm",
|
||||
"norm_self_att", "attn_norm",
|
||||
"norm_conv", "conv_norm",
|
||||
"norm_out", "out_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
"self_attn.relative_k_proj", "attn_rel_k",
|
||||
"self_attn.bias_u", "attn_bias_u",
|
||||
"self_attn.bias_v", "attn_bias_v",
|
||||
"conv.pointwise_conv1", "conv_pw1",
|
||||
"conv.pointwise_conv2", "conv_pw2",
|
||||
"conv.depthwise_conv", "conv_dw",
|
||||
"conv.norm", "conv_bn",
|
||||
"sound_projection.norm", "mm.a.norm",
|
||||
"sound_projection.linear1", "mm.a.1",
|
||||
"sound_projection.linear2", "mm.a.2",
|
||||
}, n.LLMConfig.Replacements()...)
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) specialTokenTypes() []string {
|
||||
return n.LLMConfig.specialTokenTypes()
|
||||
}
|
||||
|
||||
func isNemotronHNanoVLOmittedTensor(name string) bool {
|
||||
return strings.HasSuffix(name, ".conv_bn.num_batches_tracked") ||
|
||||
strings.HasPrefix(name, "vision_model.radio_model.input_conditioner.") ||
|
||||
strings.HasPrefix(name, "vision_model.radio_model.model.patch_generator.video_embedder")
|
||||
}
|
||||
|
||||
func squeezeLastDim(_ string, data []float32, _ []uint64) ([]float32, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionImageSize() uint32 {
|
||||
return cmp.Or(n.ForceImageSize, n.Preprocessor.ImageSize, uint32(512))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionPatchSize() uint32 {
|
||||
return cmp.Or(n.PatchSize, n.Preprocessor.PatchSize, n.VisionConfig.PatchSize, uint32(16))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionProjectorScaleFactor() uint32 {
|
||||
ratio := cmp.Or(n.DownsampleRatio, n.Preprocessor.DownsampleRatio, float32(0.5))
|
||||
if ratio <= 0 {
|
||||
return 2
|
||||
}
|
||||
|
||||
return max(uint32(1), uint32(math.Round(1.0/float64(ratio))))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionBlockCount() uint32 {
|
||||
return 32
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionEmbeddingLength() uint32 {
|
||||
return cmp.Or(n.VitHiddenSize, uint32(1280))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionAttentionHeads() uint32 {
|
||||
return 16
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionFeedForwardLength() uint32 {
|
||||
return 4 * n.visionEmbeddingLength()
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionMaxTiles() uint32 {
|
||||
return cmp.Or(n.Preprocessor.MaxNumTiles, uint32(12))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionMinNumPatches() uint32 {
|
||||
return cmp.Or(n.VisionConfig.MinNumPatches, n.VisionConfig.Args.MinNumPatches)
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionMaxNumPatches() uint32 {
|
||||
return cmp.Or(n.VisionConfig.MaxNumPatches, n.VisionConfig.Args.MaxNumPatches)
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionUseThumbnail() bool {
|
||||
for _, v := range []*bool{n.UseThumbnail, n.Preprocessor.UseThumbnail} {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionMean() []float32 {
|
||||
if len(n.NormMean) > 0 {
|
||||
return n.NormMean
|
||||
}
|
||||
return n.Preprocessor.NormMean
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) visionStd() []float32 {
|
||||
if len(n.NormStd) > 0 {
|
||||
return n.NormStd
|
||||
}
|
||||
return n.Preprocessor.NormStd
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundHiddenSize() uint32 {
|
||||
return cmp.Or(n.SoundConfig.HiddenSize, uint32(1024))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundAttentionHeads() uint32 {
|
||||
return cmp.Or(n.SoundConfig.NumAttentionHeads, uint32(8))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundFeedForwardLength() uint32 {
|
||||
return cmp.Or(n.SoundConfig.IntermediateSize, 4*n.soundHiddenSize())
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundConvKernelSize() uint32 {
|
||||
return cmp.Or(n.SoundConfig.ConvKernelSize, uint32(9))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundMelBins() uint32 {
|
||||
return cmp.Or(n.SoundConfig.NumMelBins, uint32(128))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundSampleRate() uint32 {
|
||||
return cmp.Or(n.SoundConfig.SamplingRate, uint32(16000))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundSubsamplingFactor() uint32 {
|
||||
return cmp.Or(n.SoundConfig.SubsamplingFactor, uint32(8))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundSubsamplingConvChannels() uint32 {
|
||||
return cmp.Or(n.SoundConfig.SubsamplingConvChannels, uint32(256))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundSubsamplingConvKernelSize() uint32 {
|
||||
return cmp.Or(n.SoundConfig.SubsamplingConvKernelSize, uint32(3))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundSubsamplingConvStride() uint32 {
|
||||
return cmp.Or(n.SoundConfig.SubsamplingConvStride, uint32(2))
|
||||
}
|
||||
|
||||
func (n *nemotronHNanoVLModel) soundProjectionHiddenSize() uint32 {
|
||||
return cmp.Or(n.SoundConfig.ProjectionHiddenSize, uint32(4096))
|
||||
}
|
||||
|
||||
var (
|
||||
imageNetStandardMean = []float32{0.48145466, 0.4578275, 0.40821073}
|
||||
imageNetStandardSTD = []float32{0.26862954, 0.26130258, 0.27577711}
|
||||
)
|
||||
|
||||
func (n *nemotronHModel) parseMore(_ fs.FS) error {
|
||||
if n.NumHiddenLayers == 0 {
|
||||
|
||||
@@ -217,6 +217,316 @@ func TestNemotronHLoadModelMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHNanoVLLoadModelMetadata(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := `{
|
||||
"architectures": ["NemotronH_Nano_VL_V2"],
|
||||
"model_type": "NemotronH_Nano_VL_V2",
|
||||
"max_sequence_length": 131072,
|
||||
"force_image_size": 512,
|
||||
"downsample_ratio": 0.5,
|
||||
"patch_size": 16,
|
||||
"use_thumbnail": true,
|
||||
"img_context_token_id": 18,
|
||||
"img_context_token": "<image>",
|
||||
"img_start_token": "<img>",
|
||||
"img_end_token": "</img>",
|
||||
"sound_context_token_id": 27,
|
||||
"sound_context_token": "<so_embedding>",
|
||||
"vit_hidden_size": 1280,
|
||||
"projector_hidden_size": 20480,
|
||||
"norm_mean": [0.48145466, 0.4578275, 0.40821073],
|
||||
"norm_std": [0.26862954, 0.26130258, 0.27577711],
|
||||
"vision_config": {
|
||||
"version": "radio_v2.5-h",
|
||||
"patch_size": 16,
|
||||
"max_resolution": 2048,
|
||||
"separate_video_embedder": true
|
||||
},
|
||||
"sound_config": {
|
||||
"model_type": "parakeet",
|
||||
"hidden_size": 1024,
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 24,
|
||||
"intermediate_size": 4096,
|
||||
"conv_kernel_size": 9,
|
||||
"subsampling_conv_channels": 256,
|
||||
"subsampling_conv_kernel_size": 3,
|
||||
"subsampling_conv_stride": 2,
|
||||
"subsampling_factor": 8,
|
||||
"num_mel_bins": 128,
|
||||
"projection_hidden_size": 4096,
|
||||
"sampling_rate": 16000
|
||||
},
|
||||
"llm_config": {
|
||||
"architectures": ["NemotronHForCausalLM"],
|
||||
"model_type": "nemotron_h",
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"max_position_embeddings": 262144,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 64,
|
||||
"layer_norm_epsilon": 1e-5,
|
||||
"conv_kernel": 4,
|
||||
"ssm_state_size": 128,
|
||||
"mamba_num_heads": 16,
|
||||
"mamba_head_dim": 32,
|
||||
"n_groups": 8,
|
||||
"hybrid_override_pattern": "ME*M",
|
||||
"n_routed_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 256
|
||||
}
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "preprocessor_config.json"), []byte(`{
|
||||
"image_size": 512,
|
||||
"patch_size": 16,
|
||||
"downsample_ratio": 0.5,
|
||||
"max_num_tiles": 12,
|
||||
"use_thumbnail": true,
|
||||
"norm_mean": [0.48145466, 0.4578275, 0.40821073],
|
||||
"norm_std": [0.26862954, 0.26130258, 0.27577711]
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conv, tokenizer, err := LoadModelMetadata(os.DirFS(tempDir))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := conv.(*nemotronHNanoVLModel); !ok {
|
||||
t.Fatalf("unexpected converter type: %T", conv)
|
||||
}
|
||||
|
||||
kv := conv.KV(tokenizer)
|
||||
if got, want := kv["general.architecture"], "nemotron_h_omni"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["context_length"], uint32(131072); got != want {
|
||||
t.Fatalf("unexpected context length: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.block_count"], uint32(32); got != want {
|
||||
t.Fatalf("unexpected vision block count: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.image_size"], uint32(512); got != want {
|
||||
t.Fatalf("unexpected vision image size: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.projector.scale_factor"], uint32(2); got != want {
|
||||
t.Fatalf("unexpected projector scale factor: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["audio.block_count"], uint32(24); got != want {
|
||||
t.Fatalf("unexpected audio block count: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["audio.sound_token_id"], uint32(27); got != want {
|
||||
t.Fatalf("unexpected audio token id: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["audio.subsampling_factor"], uint32(8); got != want {
|
||||
t.Fatalf("unexpected audio subsampling factor: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHNanoOmniReasoningV3LoadModelMetadata(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := `{
|
||||
"architectures": ["NemotronH_Nano_Omni_Reasoning_V3"],
|
||||
"model_type": "NemotronH_Nano_Omni_Reasoning_V3",
|
||||
"max_sequence_length": 131072,
|
||||
"force_image_size": 512,
|
||||
"downsample_ratio": 0.5,
|
||||
"patch_size": 16,
|
||||
"img_context_token_id": 18,
|
||||
"img_context_token": "<image>",
|
||||
"img_start_token": "<img>",
|
||||
"img_end_token": "</img>",
|
||||
"sound_context_token_id": 27,
|
||||
"sound_context_token": "<so_embedding>",
|
||||
"vit_hidden_size": 1280,
|
||||
"projector_hidden_size": 4096,
|
||||
"vision_config": {
|
||||
"version": "radio_v2.5-h",
|
||||
"patch_size": 16,
|
||||
"min_num_patches": 1024,
|
||||
"max_num_patches": 13312,
|
||||
"args": {
|
||||
"min_num_patches": 1024,
|
||||
"max_num_patches": 13312
|
||||
}
|
||||
},
|
||||
"sound_config": {
|
||||
"model_type": "parakeet",
|
||||
"hidden_size": 1024,
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 24,
|
||||
"intermediate_size": 4096,
|
||||
"conv_kernel_size": 9,
|
||||
"subsampling_conv_channels": 256,
|
||||
"subsampling_conv_kernel_size": 3,
|
||||
"subsampling_conv_stride": 2,
|
||||
"subsampling_factor": 8,
|
||||
"num_mel_bins": 128,
|
||||
"projection_hidden_size": 4096,
|
||||
"sampling_rate": 16000
|
||||
},
|
||||
"llm_config": {
|
||||
"architectures": ["NemotronHForCausalLM"],
|
||||
"model_type": "nemotron_h",
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"max_position_embeddings": 262144,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 64,
|
||||
"layer_norm_epsilon": 1e-5,
|
||||
"conv_kernel": 4,
|
||||
"ssm_state_size": 128,
|
||||
"mamba_num_heads": 16,
|
||||
"mamba_head_dim": 32,
|
||||
"n_groups": 8,
|
||||
"hybrid_override_pattern": "ME*M",
|
||||
"n_routed_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 256
|
||||
}
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conv, tokenizer, err := LoadModelMetadata(os.DirFS(tempDir))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := conv.(*nemotronHNanoVLModel); !ok {
|
||||
t.Fatalf("unexpected converter type: %T", conv)
|
||||
}
|
||||
|
||||
kv := conv.KV(tokenizer)
|
||||
if got, want := kv["general.architecture"], "nemotron_h_omni"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.block_count"], uint32(32); got != want {
|
||||
t.Fatalf("unexpected vision block count: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.min_num_patches"], uint32(1024); got != want {
|
||||
t.Fatalf("unexpected vision min patches: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["vision.max_num_patches"], uint32(13312); got != want {
|
||||
t.Fatalf("unexpected vision max patches: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["audio.block_count"], uint32(24); got != want {
|
||||
t.Fatalf("unexpected audio block count: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := kv["audio.sound_token_id"], uint32(27); got != want {
|
||||
t.Fatalf("unexpected audio token id: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHNanoVLTensorsRetainVisionAndAudio(t *testing.T) {
|
||||
m := &nemotronHNanoVLModel{
|
||||
LLMConfig: nemotronHModel{NGroups: 8},
|
||||
}
|
||||
|
||||
in := []Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_a",
|
||||
shape: []uint64{4},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
&fakeTensor{name: "v.blk.0.attn_qkv.weight", shape: []uint64{3840, 1280}},
|
||||
&fakeTensor{name: "v.position_embd", shape: []uint64{1, 16384, 1280}},
|
||||
&fakeTensor{name: "v.cls_embd", shape: []uint64{10, 1280}},
|
||||
&fakeTensor{name: "mm.norm.weight", shape: []uint64{5120}},
|
||||
&fakeTensor{name: "a.feature_extractor.fb", shape: []uint64{1, 128, 257}},
|
||||
&fakeTensor{name: "a.subsampling.dw1.weight", shape: []uint64{256, 1, 3, 3}},
|
||||
&fakeTensor{name: "a.blk.0.conv_dw.weight", shape: []uint64{1024, 1, 9}},
|
||||
&fakeTensor{name: "a.blk.0.conv_pw1.weight", shape: []uint64{2048, 1024, 1}},
|
||||
&fakeTensor{name: "a.blk.0.conv_bn.num_batches_tracked", shape: []uint64{1}},
|
||||
&fakeTensor{name: "mm.a.1.weight", shape: []uint64{4096, 1024}},
|
||||
}
|
||||
|
||||
out := m.Tensors(in)
|
||||
got := map[string][]uint64{}
|
||||
for _, tns := range out {
|
||||
got[tns.Name] = tns.Shape
|
||||
}
|
||||
|
||||
for _, name := range []string{
|
||||
"blk.0.ssm_a",
|
||||
"v.blk.0.attn_q.weight",
|
||||
"v.blk.0.attn_k.weight",
|
||||
"v.blk.0.attn_v.weight",
|
||||
"v.position_embd",
|
||||
"v.cls_embd",
|
||||
"mm.norm.weight",
|
||||
"a.feature_extractor.fb",
|
||||
"a.subsampling.dw1.weight",
|
||||
"a.blk.0.conv_dw.weight",
|
||||
"a.blk.0.conv_pw1.weight",
|
||||
"mm.a.1.weight",
|
||||
} {
|
||||
if _, ok := got[name]; !ok {
|
||||
t.Fatalf("expected tensor %q in output", name)
|
||||
}
|
||||
}
|
||||
|
||||
if gotShape, want := got["blk.0.ssm_a"], []uint64{4, 1}; !slices.Equal(gotShape, want) {
|
||||
t.Fatalf("unexpected ssm_a shape: got %v want %v", gotShape, want)
|
||||
}
|
||||
if gotShape, want := got["v.position_embd"], []uint64{16384, 1280}; !slices.Equal(gotShape, want) {
|
||||
t.Fatalf("unexpected position embedding shape: got %v want %v", gotShape, want)
|
||||
}
|
||||
if gotShape, want := got["a.blk.0.conv_dw.weight"], []uint64{1024, 9}; !slices.Equal(gotShape, want) {
|
||||
t.Fatalf("unexpected audio conv_dw shape: got %v want %v", gotShape, want)
|
||||
}
|
||||
if gotShape, want := got["a.blk.0.conv_pw1.weight"], []uint64{2048, 1024}; !slices.Equal(gotShape, want) {
|
||||
t.Fatalf("unexpected audio conv_pw1 shape: got %v want %v", gotShape, want)
|
||||
}
|
||||
if _, ok := got["a.blk.0.conv_bn.num_batches_tracked"]; ok {
|
||||
t.Fatal("audio batchnorm num_batches_tracked should be omitted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHNanoVLReplacements(t *testing.T) {
|
||||
m := &nemotronHNanoVLModel{}
|
||||
r := strings.NewReplacer(m.Replacements()...)
|
||||
|
||||
if got, want := r.Replace("language_model.backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
|
||||
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("language_model.lm_head.weight"), "output.weight"; got != want {
|
||||
t.Fatalf("unexpected lm_head replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("vision_model.radio_model.model.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
|
||||
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("mlp1.1.weight"), "mm.1.weight"; got != want {
|
||||
t.Fatalf("unexpected projector replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("sound_encoder.encoder.layers.0.self_attn.q_proj.weight"), "a.blk.0.attn_q.weight"; got != want {
|
||||
t.Fatalf("unexpected audio q_proj replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("sound_encoder.encoder.layers.0.conv.pointwise_conv1.weight"), "a.blk.0.conv_pw1.weight"; got != want {
|
||||
t.Fatalf("unexpected audio conv replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("sound_projection.linear2.weight"), "mm.a.2.weight"; got != want {
|
||||
t.Fatalf("unexpected audio projector replacement: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
|
||||
m := &nemotronHModel{}
|
||||
r := strings.NewReplacer(m.Replacements()...)
|
||||
|
||||
@@ -42,8 +42,10 @@ func (t tensorBase) Kind() uint32 {
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
|
||||
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
|
||||
strings.HasPrefix(t.name, "a.feature_extractor.") || // audio feature-extractor constants are read with BackendGet and must be real F32 values
|
||||
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights are kept F32 for im2col; this likely slows audio and should be revisited
|
||||
strings.HasPrefix(t.name, "a.subsampling.") || // audio Parakeet subsampling weights are kept F32 for conv/linear stability; this likely slows audio and should be revisited
|
||||
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights are kept F32; this likely slows audio and should be revisited
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.position_embd.weight" ||
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -23,6 +25,11 @@ type safetensorMetadata struct {
|
||||
}
|
||||
|
||||
func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
|
||||
fp8Block, err := safetensorsFP8BlockSize(fsys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ts []Tensor
|
||||
for _, p := range ps {
|
||||
f, err := fsys.Open(p)
|
||||
@@ -50,24 +57,47 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
|
||||
names := make(map[string]struct{}, len(keys))
|
||||
|
||||
fp8Scales, err := collectSafetensorsFP8Scales(n, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if value := headers[key]; value.Type != "" {
|
||||
if _, ok := fp8Scales.consumed[key]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
|
||||
// Promote them to 1-dim so they can be stored in GGUF.
|
||||
if len(value.Shape) == 0 {
|
||||
value.Shape = []uint64{1}
|
||||
}
|
||||
|
||||
var scale *safetensorScale
|
||||
if value.Type == "F8_E4M3" {
|
||||
if !fp8Block.ok {
|
||||
return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", key)
|
||||
}
|
||||
scale = fp8Scales.byWeight[key]
|
||||
if scale == nil {
|
||||
return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
ggufName := replacer.Replace(key)
|
||||
if _, ok := names[ggufName]; ok {
|
||||
return nil, fmt.Errorf("duplicate tensor name '%s' was found for this model", ggufName)
|
||||
}
|
||||
names[ggufName] = struct{}{}
|
||||
ts = append(ts, safetensor{
|
||||
fs: fsys,
|
||||
path: p,
|
||||
dtype: value.Type,
|
||||
offset: safetensorsPad(n, value.Offsets[0]),
|
||||
size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
|
||||
fs: fsys,
|
||||
path: p,
|
||||
dtype: value.Type,
|
||||
offset: safetensorsPad(n, value.Offsets[0]),
|
||||
size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
|
||||
scale: scale,
|
||||
fp8Block: fp8Block,
|
||||
tensorBase: &tensorBase{
|
||||
name: ggufName,
|
||||
shape: value.Shape,
|
||||
@@ -85,12 +115,22 @@ func safetensorsPad(n, offset int64) int64 {
|
||||
return 8 + n + offset
|
||||
}
|
||||
|
||||
type safetensor struct {
|
||||
fs fs.FS
|
||||
path string
|
||||
type safetensorScale struct {
|
||||
name string
|
||||
dtype string
|
||||
shape []uint64
|
||||
offset int64
|
||||
size int64
|
||||
}
|
||||
|
||||
type safetensor struct {
|
||||
fs fs.FS
|
||||
path string
|
||||
dtype string
|
||||
offset int64
|
||||
size int64
|
||||
scale *safetensorScale
|
||||
fp8Block safetensorFP8BlockSize
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
@@ -104,17 +144,26 @@ func (st safetensor) Kind() uint32 {
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
if st.dtype == "F8_E4M3" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
return kind
|
||||
}
|
||||
|
||||
func (st safetensor) SourceDType() string {
|
||||
return st.dtype
|
||||
}
|
||||
|
||||
func (st safetensor) Clone() Tensor {
|
||||
return &safetensor{
|
||||
fs: st.fs,
|
||||
path: st.path,
|
||||
dtype: st.dtype,
|
||||
offset: st.offset,
|
||||
size: st.size,
|
||||
fs: st.fs,
|
||||
path: st.path,
|
||||
dtype: st.dtype,
|
||||
offset: st.offset,
|
||||
size: st.size,
|
||||
scale: st.scale.Clone(),
|
||||
fp8Block: st.fp8Block,
|
||||
tensorBase: &tensorBase{
|
||||
name: st.name,
|
||||
repacker: st.repacker,
|
||||
@@ -123,6 +172,19 @@ func (st safetensor) Clone() Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *safetensorScale) Clone() *safetensorScale {
|
||||
if ss == nil {
|
||||
return nil
|
||||
}
|
||||
return &safetensorScale{
|
||||
name: ss.name,
|
||||
dtype: ss.dtype,
|
||||
shape: slices.Clone(ss.shape),
|
||||
offset: ss.offset,
|
||||
size: ss.size,
|
||||
}
|
||||
}
|
||||
|
||||
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
f, err := st.fs.Open(st.path)
|
||||
if err != nil {
|
||||
@@ -180,6 +242,16 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
|
||||
f32s = bfloat16.DecodeFloat32(u8s)
|
||||
case "F8_E4M3":
|
||||
u8s := make([]uint8, st.size)
|
||||
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f32s, err = st.decodeFP8E4M3(u8s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
||||
}
|
||||
@@ -208,3 +280,334 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
}
|
||||
|
||||
type safetensorsFP8Scales struct {
|
||||
byWeight map[string]*safetensorScale
|
||||
consumed map[string]struct{}
|
||||
}
|
||||
|
||||
func collectSafetensorsFP8Scales(n int64, headers map[string]safetensorMetadata) (safetensorsFP8Scales, error) {
|
||||
scales := safetensorsFP8Scales{
|
||||
byWeight: make(map[string]*safetensorScale),
|
||||
consumed: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
if value.Type != "F8_E4M3" {
|
||||
continue
|
||||
}
|
||||
|
||||
scaleKey, scaleValue, ok, err := safetensorsFP8Scale(key, headers)
|
||||
if err != nil {
|
||||
return safetensorsFP8Scales{}, err
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := scales.consumed[scaleKey]; ok {
|
||||
return safetensorsFP8Scales{}, fmt.Errorf("fp8 scale companion %q is used by multiple tensors", scaleKey)
|
||||
}
|
||||
|
||||
scales.byWeight[key] = &safetensorScale{
|
||||
name: scaleKey,
|
||||
dtype: scaleValue.Type,
|
||||
shape: slices.Clone(scaleValue.Shape),
|
||||
offset: safetensorsPad(n, scaleValue.Offsets[0]),
|
||||
size: safetensorsPad(n, scaleValue.Offsets[1]) - safetensorsPad(n, scaleValue.Offsets[0]),
|
||||
}
|
||||
scales.consumed[scaleKey] = struct{}{}
|
||||
}
|
||||
|
||||
return scales, nil
|
||||
}
|
||||
|
||||
func safetensorsFP8Scale(key string, headers map[string]safetensorMetadata) (string, safetensorMetadata, bool, error) {
|
||||
candidates := safetensorsFP8ScaleCandidates(key)
|
||||
|
||||
var scaleKey string
|
||||
var scaleValue safetensorMetadata
|
||||
if strings.HasSuffix(key, ".weight") {
|
||||
// Keep support for compressed-tensors exports that place the scale name
|
||||
// between the module path and weight suffix.
|
||||
base := strings.TrimSuffix(key, ".weight")
|
||||
candidates = appendUnique(candidates, base+".weight_scale")
|
||||
candidates = appendUnique(candidates, base+".weight_scale_inv")
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if value, ok := headers[candidate]; ok && value.Type != "" {
|
||||
if scaleKey != "" {
|
||||
return "", safetensorMetadata{}, false, fmt.Errorf("multiple fp8 scale companions for tensor %q: %q and %q", key, scaleKey, candidate)
|
||||
}
|
||||
scaleKey = candidate
|
||||
scaleValue = value
|
||||
}
|
||||
}
|
||||
if scaleKey == "" {
|
||||
return "", safetensorMetadata{}, false, nil
|
||||
}
|
||||
|
||||
return scaleKey, scaleValue, true, nil
|
||||
}
|
||||
|
||||
func safetensorsFP8ScaleCandidates(key string) []string {
|
||||
var candidates []string
|
||||
candidates = appendUnique(candidates, key+"_scale")
|
||||
candidates = appendUnique(candidates, key+"_scale_inv")
|
||||
candidates = appendUnique(candidates, key+".scale")
|
||||
candidates = appendUnique(candidates, key+".scale_inv")
|
||||
return candidates
|
||||
}
|
||||
|
||||
func appendUnique(values []string, value string) []string {
|
||||
if !slices.Contains(values, value) {
|
||||
values = append(values, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
type safetensorFP8BlockSize struct {
|
||||
rows int
|
||||
cols int
|
||||
ok bool
|
||||
}
|
||||
|
||||
type safetensorsSourceQuantization struct {
|
||||
QuantMethod string `json:"quant_method"`
|
||||
Format string `json:"format"`
|
||||
WeightBlockSize []int `json:"weight_block_size"`
|
||||
ConfigGroups map[string]struct {
|
||||
Format string `json:"format"`
|
||||
Weights struct {
|
||||
BlockStructure []int `json:"block_structure"`
|
||||
NumBits int `json:"num_bits"`
|
||||
Type string `json:"type"`
|
||||
} `json:"weights"`
|
||||
} `json:"config_groups"`
|
||||
}
|
||||
|
||||
type safetensorsModelConfig struct {
|
||||
Quantization safetensorsSourceQuantization `json:"quantization"`
|
||||
QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"`
|
||||
CompressionConfig safetensorsSourceQuantization `json:"compression_config"`
|
||||
TextConfig struct {
|
||||
Quantization safetensorsSourceQuantization `json:"quantization"`
|
||||
QuantizationConfig safetensorsSourceQuantization `json:"quantization_config"`
|
||||
CompressionConfig safetensorsSourceQuantization `json:"compression_config"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
func safetensorsFP8BlockSize(fsys fs.FS) (safetensorFP8BlockSize, error) {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return safetensorFP8BlockSize{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return safetensorFP8BlockSize{}, err
|
||||
}
|
||||
bts = sanitizeNonFiniteJSON(bts)
|
||||
|
||||
var cfg safetensorsModelConfig
|
||||
if err := json.Unmarshal(bts, &cfg); err != nil {
|
||||
return safetensorFP8BlockSize{}, fmt.Errorf("parse config.json fp8 metadata: %w", err)
|
||||
}
|
||||
|
||||
var blocks []safetensorFP8BlockSize
|
||||
for _, q := range []safetensorsSourceQuantization{
|
||||
cfg.Quantization,
|
||||
cfg.QuantizationConfig,
|
||||
cfg.CompressionConfig,
|
||||
cfg.TextConfig.Quantization,
|
||||
cfg.TextConfig.QuantizationConfig,
|
||||
cfg.TextConfig.CompressionConfig,
|
||||
} {
|
||||
if strings.EqualFold(q.QuantMethod, "fp8") && len(q.WeightBlockSize) == 2 {
|
||||
block, err := newSafetensorFP8BlockSize(q.WeightBlockSize[0], q.WeightBlockSize[1])
|
||||
if err != nil {
|
||||
return safetensorFP8BlockSize{}, err
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(q.QuantMethod, "compressed-tensors") && !strings.EqualFold(q.Format, "float-quantized") {
|
||||
continue
|
||||
}
|
||||
for _, group := range q.ConfigGroups {
|
||||
if !strings.EqualFold(group.Format, "float-quantized") ||
|
||||
group.Weights.NumBits != 8 ||
|
||||
!strings.EqualFold(group.Weights.Type, "float") ||
|
||||
len(group.Weights.BlockStructure) != 2 {
|
||||
continue
|
||||
}
|
||||
block, err := newSafetensorFP8BlockSize(group.Weights.BlockStructure[0], group.Weights.BlockStructure[1])
|
||||
if err != nil {
|
||||
return safetensorFP8BlockSize{}, err
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
return safetensorFP8BlockSize{}, nil
|
||||
}
|
||||
|
||||
block := blocks[0]
|
||||
for _, other := range blocks[1:] {
|
||||
if other.rows != block.rows || other.cols != block.cols {
|
||||
return safetensorFP8BlockSize{}, fmt.Errorf("multiple fp8 block sizes in config.json: %dx%d and %dx%d", block.rows, block.cols, other.rows, other.cols)
|
||||
}
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
func newSafetensorFP8BlockSize(rows, cols int) (safetensorFP8BlockSize, error) {
|
||||
if rows <= 0 || cols <= 0 {
|
||||
return safetensorFP8BlockSize{}, fmt.Errorf("invalid fp8 block size %dx%d", rows, cols)
|
||||
}
|
||||
return safetensorFP8BlockSize{rows: rows, cols: cols, ok: true}, nil
|
||||
}
|
||||
|
||||
func (st safetensor) decodeFP8E4M3(data []byte) ([]float32, error) {
|
||||
if st.scale == nil {
|
||||
return nil, fmt.Errorf("missing fp8 scale companion for tensor %q", st.name)
|
||||
}
|
||||
if !st.fp8Block.ok {
|
||||
return nil, fmt.Errorf("missing fp8 block size metadata for tensor %q", st.name)
|
||||
}
|
||||
if len(st.shape) != 2 {
|
||||
return nil, fmt.Errorf("expected 2D fp8 tensor %q, got shape %v", st.name, st.shape)
|
||||
}
|
||||
|
||||
rows, cols := int(st.shape[0]), int(st.shape[1])
|
||||
if rows < 0 || cols < 0 || rows*cols != len(data) {
|
||||
return nil, fmt.Errorf("fp8 tensor %q shape %v does not match %d bytes", st.name, st.shape, len(data))
|
||||
}
|
||||
|
||||
scale, err := st.readScale()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(st.scale.shape) != 2 {
|
||||
return nil, fmt.Errorf("expected 2D fp8 scale tensor %q, got shape %v", st.scale.name, st.scale.shape)
|
||||
}
|
||||
|
||||
blockRows := st.fp8Block.rows
|
||||
blockCols := st.fp8Block.cols
|
||||
scaleRows, scaleCols := int(st.scale.shape[0]), int(st.scale.shape[1])
|
||||
expectedRows := (rows + blockRows - 1) / blockRows
|
||||
expectedCols := (cols + blockCols - 1) / blockCols
|
||||
if scaleRows != expectedRows || scaleCols != expectedCols {
|
||||
return nil, fmt.Errorf("unexpected fp8 scale shape %v for tensor %q shape %v; want [%d %d]", st.scale.shape, st.name, st.shape, expectedRows, expectedCols)
|
||||
}
|
||||
if len(scale) != scaleRows*scaleCols {
|
||||
return nil, fmt.Errorf("fp8 scale tensor %q shape %v does not match decoded length %d", st.scale.name, st.scale.shape, len(scale))
|
||||
}
|
||||
|
||||
f32s := make([]float32, len(data))
|
||||
for r := range rows {
|
||||
scaleRow := r / blockRows
|
||||
rowOffset := r * cols
|
||||
for c := range cols {
|
||||
f32s[rowOffset+c] = decodeFloat8E4M3FN(data[rowOffset+c]) * scale[scaleRow*scaleCols+c/blockCols]
|
||||
}
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
func (st safetensor) readScale() ([]float32, error) {
|
||||
r, err := st.sectionReader(st.scale.offset, st.scale.size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read fp8 scale tensor %q: %w", st.scale.name, err)
|
||||
}
|
||||
if closer, ok := r.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
br := bufio.NewReaderSize(r, min(32<<10, int(st.scale.size)))
|
||||
|
||||
switch st.scale.dtype {
|
||||
case "F32":
|
||||
f32s := make([]float32, st.scale.size/4)
|
||||
if err := binary.Read(br, binary.LittleEndian, f32s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f32s, nil
|
||||
case "F16":
|
||||
u16s := make([]uint16, st.scale.size/2)
|
||||
if err := binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f32s := make([]float32, len(u16s))
|
||||
for i := range u16s {
|
||||
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
||||
}
|
||||
return f32s, nil
|
||||
case "BF16":
|
||||
u8s := make([]uint8, st.scale.size)
|
||||
if err := binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bfloat16.DecodeFloat32(u8s), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported fp8 scale dtype %q for tensor %q", st.scale.dtype, st.scale.name)
|
||||
}
|
||||
}
|
||||
|
||||
func (st safetensor) sectionReader(offset, size int64) (io.Reader, error) {
|
||||
f, err := st.fs.Open(st.path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if readerAt, ok := f.(io.ReaderAt); ok {
|
||||
return &readCloserReader{
|
||||
Reader: io.NewSectionReader(readerAt, offset, size),
|
||||
Closer: f,
|
||||
}, nil
|
||||
}
|
||||
if seeker, ok := f.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(offset, io.SeekStart); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &readCloserReader{
|
||||
Reader: io.LimitReader(f, size),
|
||||
Closer: f,
|
||||
}, nil
|
||||
}
|
||||
if _, err := io.CopyN(io.Discard, f, offset); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &readCloserReader{
|
||||
Reader: io.LimitReader(f, size),
|
||||
Closer: f,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type readCloserReader struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func decodeFloat8E4M3FN(v byte) float32 {
|
||||
sign := float32(1)
|
||||
if v&0x80 != 0 {
|
||||
sign = -1
|
||||
}
|
||||
|
||||
exp := int((v >> 3) & 0x0f)
|
||||
mant := int(v & 0x07)
|
||||
if exp == 0 {
|
||||
if mant == 0 {
|
||||
return 0 * sign
|
||||
}
|
||||
return sign * float32(math.Ldexp(float64(mant)/8, -6))
|
||||
}
|
||||
if exp == 0x0f && mant == 0x07 {
|
||||
return float32(math.NaN())
|
||||
}
|
||||
|
||||
return sign * float32(math.Ldexp(1+float64(mant)/8, exp-7))
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package convert
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
@@ -231,6 +233,222 @@ func TestSafetensors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorWriteToFP8E4M3(t *testing.T) {
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
path := filepath.Base(t.Name())
|
||||
f, err := root.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// E4M3FN encodings for 1.0, 2.0, 0.5, and -1.0.
|
||||
if _, err := f.Write([]byte{0x38, 0x40, 0x30, 0xb8}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := f.Write(bfloat16.EncodeFloat32([]float32{2})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
st := safetensor{
|
||||
fs: root.FS(),
|
||||
path: path,
|
||||
dtype: "F8_E4M3",
|
||||
offset: 0,
|
||||
size: 4,
|
||||
fp8Block: safetensorFP8BlockSize{rows: 128, cols: 128, ok: true},
|
||||
scale: &safetensorScale{
|
||||
name: "linear.weight_scale",
|
||||
dtype: "BF16",
|
||||
shape: []uint64{1, 1},
|
||||
offset: 4,
|
||||
size: 2,
|
||||
},
|
||||
tensorBase: &tensorBase{
|
||||
name: "linear.weight",
|
||||
shape: []uint64{2, 2},
|
||||
},
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := st.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := bfloat16.EncodeFloat32([]float32{2, 4, 1, -2})
|
||||
if diff := cmp.Diff(want, b.Bytes()); diff != "" {
|
||||
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorWriteToFP8E4M3UsesConfiguredBlockSize(t *testing.T) {
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
path := filepath.Base(t.Name())
|
||||
f, err := root.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(bytes.Repeat([]byte{0x38}, 12)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := f.Write(bfloat16.EncodeFloat32([]float32{1, 2, 3, 4})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
st := safetensor{
|
||||
fs: root.FS(),
|
||||
path: path,
|
||||
dtype: "F8_E4M3",
|
||||
offset: 0,
|
||||
size: 12,
|
||||
fp8Block: safetensorFP8BlockSize{rows: 2, cols: 3, ok: true},
|
||||
scale: &safetensorScale{
|
||||
name: "linear.weight_scale",
|
||||
dtype: "BF16",
|
||||
shape: []uint64{2, 2},
|
||||
offset: 12,
|
||||
size: 8,
|
||||
},
|
||||
tensorBase: &tensorBase{
|
||||
name: "linear.weight",
|
||||
shape: []uint64{3, 4},
|
||||
},
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := st.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := bfloat16.EncodeFloat32([]float32{
|
||||
1, 1, 1, 2,
|
||||
1, 1, 1, 2,
|
||||
3, 3, 3, 4,
|
||||
})
|
||||
if diff := cmp.Diff(want, b.Bytes()); diff != "" {
|
||||
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsConsumesFP8ScaleCompanion(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
|
||||
"linear.weight": {
|
||||
Offsets: []int{0, 4},
|
||||
Type: "F8_E4M3",
|
||||
Shape: []int{2, 2},
|
||||
},
|
||||
"linear.weight_scale": {
|
||||
Offsets: []int{4, 6},
|
||||
Type: "BF16",
|
||||
Shape: []int{1, 1},
|
||||
},
|
||||
})
|
||||
writeFP8BlockConfig(t, tempDir, 128, 128)
|
||||
|
||||
tensors, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(tensors) != 1 {
|
||||
t.Fatalf("expected one tensor, got %d", len(tensors))
|
||||
}
|
||||
if got := tensors[0].Name(); got != "linear.weight" {
|
||||
t.Fatalf("unexpected tensor name %q", got)
|
||||
}
|
||||
if got := tensors[0].Kind(); got != tensorKindBF16 {
|
||||
t.Fatalf("unexpected fp8 converted kind %d, want %d", got, tensorKindBF16)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsRejectsFP8WithoutBlockMetadata(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
|
||||
"linear.weight": {
|
||||
Offsets: []int{0, 4},
|
||||
Type: "F8_E4M3",
|
||||
Shape: []int{2, 2},
|
||||
},
|
||||
"linear.weight_scale": {
|
||||
Offsets: []int{4, 6},
|
||||
Type: "BF16",
|
||||
Shape: []int{1, 1},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
|
||||
if err == nil || !strings.Contains(err.Error(), "missing fp8 block size metadata") {
|
||||
t.Fatalf("expected missing fp8 block size metadata error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsRejectsAmbiguousFP8ScaleCompanion(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
generateSafetensorTestData(t, tempDir, map[string]*tensorData{
|
||||
"linear.weight": {
|
||||
Offsets: []int{0, 4},
|
||||
Type: "F8_E4M3",
|
||||
Shape: []int{2, 2},
|
||||
},
|
||||
"linear.weight_scale": {
|
||||
Offsets: []int{4, 6},
|
||||
Type: "BF16",
|
||||
Shape: []int{1, 1},
|
||||
},
|
||||
"linear.weight.scale": {
|
||||
Offsets: []int{6, 8},
|
||||
Type: "BF16",
|
||||
Shape: []int{1, 1},
|
||||
},
|
||||
})
|
||||
writeFP8BlockConfig(t, tempDir, 128, 128)
|
||||
|
||||
_, err := parseSafetensors(os.DirFS(tempDir), strings.NewReplacer(), "model-00001-of-00001.safetensors")
|
||||
if err == nil || !strings.Contains(err.Error(), "multiple fp8 scale companions") {
|
||||
t.Fatalf("expected ambiguous fp8 scale companion error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeFP8BlockConfig(t *testing.T, dir string, rows, cols int) {
|
||||
t.Helper()
|
||||
|
||||
config := fmt.Sprintf(`{
|
||||
"architectures": ["GenericForCausalLM"],
|
||||
"compression_config": {
|
||||
"format": "float-quantized",
|
||||
"config_groups": {
|
||||
"group_0": {
|
||||
"format": "float-quantized",
|
||||
"weights": {
|
||||
"type": "float",
|
||||
"num_bits": 8,
|
||||
"block_structure": [%d, %d]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, rows, cols)
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -259,6 +477,17 @@ func TestSafetensorKind(t *testing.T) {
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
{
|
||||
name: "BF16 audio feature extractor constants should return FP32",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "a.feature_extractor.fb",
|
||||
shape: []uint64{1, 128, 257},
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP32,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with FP32 base kind should return FP32",
|
||||
st: safetensor{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"iter"
|
||||
"maps"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -153,3 +154,54 @@ func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func sourceTensorKV(ts []*ggml.Tensor) KV {
|
||||
sourceFP8 := make(map[string]struct{})
|
||||
for _, t := range ts {
|
||||
if writerSourceDType(t.WriterTo) == "F8_E4M3" {
|
||||
sourceFP8[t.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(sourceFP8) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return KV{
|
||||
"source_quantization": "hf_fp8",
|
||||
"source_fp8_tensors": slices.Sorted(maps.Keys(sourceFP8)),
|
||||
}
|
||||
}
|
||||
|
||||
type sourceDTypeTensor interface {
|
||||
SourceDType() string
|
||||
}
|
||||
|
||||
func writerSourceDType(w io.WriterTo) string {
|
||||
switch w := w.(type) {
|
||||
case sourceDTypeTensor:
|
||||
return w.SourceDType()
|
||||
case mergeGroup:
|
||||
if len(w) == 0 {
|
||||
return ""
|
||||
}
|
||||
dtype := sourceDType(w[0])
|
||||
if dtype == "" {
|
||||
return ""
|
||||
}
|
||||
for _, t := range w[1:] {
|
||||
if sourceDType(t) != dtype {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return dtype
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func sourceDType(t Tensor) string {
|
||||
if t, ok := t.(sourceDTypeTensor); ok {
|
||||
return t.SourceDType()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -21,7 +21,8 @@ type fakeTensor struct {
|
||||
shape []uint64
|
||||
data []float32
|
||||
|
||||
repacker Repacker
|
||||
sourceDType string
|
||||
repacker Repacker
|
||||
}
|
||||
|
||||
func (f fakeTensor) Name() string {
|
||||
@@ -36,16 +37,21 @@ func (f fakeTensor) Kind() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f fakeTensor) SourceDType() string {
|
||||
return f.sourceDType
|
||||
}
|
||||
|
||||
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
||||
f.repacker = fn
|
||||
}
|
||||
|
||||
func (f fakeTensor) Clone() Tensor {
|
||||
return &fakeTensor{
|
||||
name: f.name,
|
||||
shape: slices.Clone(f.shape),
|
||||
data: slices.Clone(f.data),
|
||||
repacker: f.repacker,
|
||||
name: f.name,
|
||||
shape: slices.Clone(f.shape),
|
||||
data: slices.Clone(f.data),
|
||||
sourceDType: f.sourceDType,
|
||||
repacker: f.repacker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -995,3 +1001,43 @@ func TestMergeOrder(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceTensorKVRecordsFP8OutputTensors(t *testing.T) {
|
||||
fp8 := &fakeTensor{name: "linear.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
|
||||
bf16 := &fakeTensor{name: "other.weight", shape: []uint64{2, 2}, sourceDType: "BF16"}
|
||||
|
||||
kv := sourceTensorKV([]*ggml.Tensor{
|
||||
{Name: "blk.0.linear.weight", WriterTo: fp8},
|
||||
{Name: "blk.0.other.weight", WriterTo: bf16},
|
||||
})
|
||||
|
||||
if got := kv["source_quantization"]; got != "hf_fp8" {
|
||||
t.Fatalf("source_quantization = %v, want hf_fp8", got)
|
||||
}
|
||||
got, ok := kv["source_fp8_tensors"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"blk.0.linear.weight"}, got); diff != "" {
|
||||
t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceTensorKVRecordsMergedFP8OutputTensors(t *testing.T) {
|
||||
fp8A := &fakeTensor{name: "expert.0.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
|
||||
fp8B := &fakeTensor{name: "expert.1.weight", shape: []uint64{2, 2}, sourceDType: "F8_E4M3"}
|
||||
bf16 := &fakeTensor{name: "expert.2.weight", shape: []uint64{2, 2}, sourceDType: "BF16"}
|
||||
|
||||
kv := sourceTensorKV([]*ggml.Tensor{
|
||||
{Name: "ffn_exps.weight", WriterTo: mergeGroup{fp8A, fp8B}},
|
||||
{Name: "mixed_exps.weight", WriterTo: mergeGroup{fp8A, bf16}},
|
||||
})
|
||||
|
||||
got, ok := kv["source_fp8_tensors"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("source_fp8_tensors = %#v, want []string", kv["source_fp8_tensors"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"ffn_exps.weight"}, got); diff != "" {
|
||||
t.Fatalf("source_fp8_tensors mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
t.Pre = "qwen2"
|
||||
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
|
||||
t.Pre = "qwen35"
|
||||
case "b92c0824a58e1d8dc3221cf3e12c433c3a86f57e46d57229993489f0798e7702":
|
||||
t.Pre = "laguna"
|
||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||
// noop, empty pretokenizer
|
||||
default:
|
||||
|
||||
@@ -124,7 +124,8 @@
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/pi"
|
||||
"/integrations/pi",
|
||||
"/integrations/poolside"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -15,6 +15,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
- [Pi](/integrations/pi)
|
||||
- [Poolside](/integrations/poolside)
|
||||
|
||||
## Assistants
|
||||
|
||||
|
||||
54
docs/integrations/poolside.mdx
Normal file
54
docs/integrations/poolside.mdx
Normal file
@@ -0,0 +1,54 @@
|
||||
---
|
||||
title: Poolside
|
||||
---
|
||||
|
||||
Poolside is Poolside's software agent for the terminal, built for enterprise development workflows.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Poolside](https://github.com/poolsideai/pool):
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch pool
|
||||
```
|
||||
|
||||
### Run directly with a model
|
||||
|
||||
```shell
|
||||
ollama launch pool --model kimi-k2.6:cloud
|
||||
```
|
||||
|
||||
### Pass arguments through to Poolside
|
||||
|
||||
Arguments after `--` are passed directly to Poolside:
|
||||
|
||||
```shell
|
||||
ollama launch pool -- --help
|
||||
```
|
||||
|
||||
## Manual setup
|
||||
|
||||
Poolside connects to Ollama using the OpenAI-compatible API via environment variables.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1
|
||||
export POOLSIDE_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Poolside with an Ollama model:
|
||||
|
||||
```shell
|
||||
pool -m kimi-k2.6:cloud
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
POOLSIDE_STANDALONE_BASE_URL=http://localhost:11434/v1 POOLSIDE_API_KEY=ollama pool -m kimi-k2.6:cloud
|
||||
```
|
||||
@@ -283,10 +283,11 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"gptoss", "gpt-oss",
|
||||
"laguna",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"nemotron_h", "nemotron_h_moe",
|
||||
"nemotron_h", "nemotron_h_moe", "nemotron_h_omni",
|
||||
"nomic-bert",
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
@@ -897,7 +898,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"lfm2",
|
||||
"lfm2moe",
|
||||
"mistral3",
|
||||
"nemotron_h", "nemotron_h_moe",
|
||||
"nemotron_h", "nemotron_h_moe", "nemotron_h_omni",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen35", "qwen35moe",
|
||||
|
||||
444
model/models/laguna/model.go
Normal file
444
model/models/laguna/model.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package laguna
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
headDim int
|
||||
|
||||
numHeads []int
|
||||
numKVHeads int
|
||||
|
||||
eps float32
|
||||
|
||||
slidingWindow int
|
||||
slidingWindowPattern []bool
|
||||
|
||||
fullRopeDim int
|
||||
fullRopeBase, fullRopeScale float32
|
||||
fullRopeOriginalContextLength int
|
||||
fullRopeAttentionFactor float32
|
||||
fullRopeBetaFast float32
|
||||
fullRopeBetaSlow float32
|
||||
|
||||
swaRopeDim int
|
||||
swaRopeBase, swaRopeScale float32
|
||||
numExperts, numExpertsUsed int
|
||||
normTopKProb bool
|
||||
routedScalingFactor float32
|
||||
decoderSparseStep int
|
||||
denseLayers map[int]bool
|
||||
}
|
||||
|
||||
func (o *Options) numHeadsForLayer(layer int) int {
|
||||
if layer < len(o.numHeads) && o.numHeads[layer] > 0 {
|
||||
return o.numHeads[layer]
|
||||
}
|
||||
if len(o.numHeads) > 0 && o.numHeads[0] > 0 {
|
||||
return o.numHeads[0]
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (o *Options) layerIsSliding(layer int) bool {
|
||||
return layer < len(o.slidingWindowPattern) && o.slidingWindowPattern[layer]
|
||||
}
|
||||
|
||||
func (o *Options) layerUsesMoE(layer int) bool {
|
||||
if o.numExperts == 0 || o.denseLayers[layer] {
|
||||
return false
|
||||
}
|
||||
step := o.decoderSparseStep
|
||||
if step <= 0 {
|
||||
step = 1
|
||||
}
|
||||
return (layer+1)%step == 0
|
||||
}
|
||||
|
||||
func (o *Options) applyRotaryPositionEmbeddings(ctx ml.Context, layer int, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.layerIsSliding(layer) {
|
||||
return nn.RoPE(ctx, states, positions, o.swaRopeDim, o.swaRopeBase, 1./o.swaRopeScale, opts...)
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.fullRopeOriginalContextLength),
|
||||
rope.WithExtrapolationFactor(1),
|
||||
rope.WithAttentionFactor(o.fullRopeAttentionFactor),
|
||||
rope.WithBetaFast(o.fullRopeBetaFast),
|
||||
rope.WithBetaSlow(o.fullRopeBetaSlow),
|
||||
)
|
||||
return nn.RoPE(ctx, states, positions, o.fullRopeDim, o.fullRopeBase, 1./o.fullRopeScale, opts...)
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Gate *nn.Linear `gguf:"attn_g"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *Attention) Forward(ctx ml.Context, layer int, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
numHeads := opts.numHeadsForLayer(layer)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
gate := sa.Gate.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, opts.headDim, numHeads, batchSize)
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, layer, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, layer, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), cache)
|
||||
// Laguna applies the per-head gate softplus in float32, then casts back.
|
||||
gate = gate.Cast(ctx, ml.DTypeF32).Softplus(ctx).Cast(ctx, attention.DType())
|
||||
attention = attention.Mul(ctx, gate.Reshape(ctx, 1, numHeads, batchSize))
|
||||
attention = attention.Reshape(ctx, opts.headDim*numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP interface {
|
||||
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||
}
|
||||
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
SharedExpert *dense `gguf:",suf:_shexp"`
|
||||
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||
}
|
||||
|
||||
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
|
||||
if moe.ExpProbsBias != nil {
|
||||
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||
}
|
||||
return scores.TopK(ctx, opts.numExpertsUsed)
|
||||
}
|
||||
|
||||
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
|
||||
scores := moe.Router.Forward(ctx, hiddenStates).Cast(ctx, ml.DTypeF32).Sigmoid(ctx)
|
||||
selectedExperts := moe.topKIndices(ctx, scores, opts)
|
||||
routingWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
}
|
||||
routingWeights = routingWeights.Scale(ctx, float64(opts.routedScalingFactor))
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
upStates := moe.Up.Forward(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = moe.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, upStates)
|
||||
|
||||
experts := moe.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates.Add(ctx, moe.SharedExpert.Forward(ctx, residual, opts))
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
*Attention
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.Attention.Forward(ctx, layer, hiddenStates, positions, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
residual = hiddenStates
|
||||
|
||||
hiddenStates = l.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if c.Bool("attention.sink_enabled") {
|
||||
return nil, fmt.Errorf("laguna: SWA attention sinks are not supported")
|
||||
}
|
||||
if c.Uint("attention.gating_type") != 1 {
|
||||
return nil, fmt.Errorf("laguna: unsupported attention gating type %d", c.Uint("attention.gating_type"))
|
||||
}
|
||||
if !c.Bool("attention.qk_norm") {
|
||||
return nil, fmt.Errorf("laguna: Q/K RMSNorm is required")
|
||||
}
|
||||
if gating := c.Uint("expert_gating_func"); gating != 2 {
|
||||
return nil, fmt.Errorf("laguna: unsupported expert gating function %d", gating)
|
||||
}
|
||||
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
opts := newOptions(c, numLayers)
|
||||
layers := make([]Layer, numLayers)
|
||||
for i := range layers {
|
||||
if opts.layerUsesMoE(i) {
|
||||
layers[i].MLP = &sparse{}
|
||||
} else {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
var pre []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
case "laguna":
|
||||
pre = []string{
|
||||
`(?:\r?\n)+(?!\r?\n)`,
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
}
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
pre...,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(opts.slidingWindow), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func newOptions(c fs.Config, numLayers int) *Options {
|
||||
denseLayers := make(map[int]bool)
|
||||
for _, layer := range configUints(c, "dense_layers") {
|
||||
denseLayers[int(layer)] = true
|
||||
}
|
||||
for i := range c.Uint("leading_dense_block_count") {
|
||||
denseLayers[int(i)] = true
|
||||
}
|
||||
|
||||
fullRopeScale := c.Float("rope.scaling.factor", 1)
|
||||
if fullRopeScale == 0 {
|
||||
fullRopeScale = 1
|
||||
}
|
||||
swaRopeScale := c.Float("rope.swa.scaling.factor", 1)
|
||||
if swaRopeScale == 0 {
|
||||
swaRopeScale = 1
|
||||
}
|
||||
fullRopeType := c.String("rope.scaling.type")
|
||||
fullRopeAttentionFactor := lagunaAttentionFactor(fullRopeType, fullRopeScale, c.Float("rope.scaling.attn_factor"))
|
||||
|
||||
return &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
numHeads: expandIntArray(configUints(c, "attention.head_count"), numLayers, c.Uint("attention.head_count", 1)),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
|
||||
slidingWindow: int(c.Uint("attention.sliding_window", 512)),
|
||||
slidingWindowPattern: slidingWindowPattern(c, numLayers),
|
||||
fullRopeDim: int(c.Uint("rope.dimension_count", c.Uint("attention.key_length"))),
|
||||
fullRopeBase: c.Float("rope.freq_base", 500000),
|
||||
fullRopeScale: fullRopeScale,
|
||||
fullRopeOriginalContextLength: int(c.Uint("rope.scaling.original_context_length", 4096)),
|
||||
fullRopeAttentionFactor: fullRopeAttentionFactor,
|
||||
fullRopeBetaFast: c.Float("rope.scaling.beta_fast", 64),
|
||||
fullRopeBetaSlow: c.Float("rope.scaling.beta_slow", 1),
|
||||
swaRopeDim: int(c.Uint("rope.swa.dimension_count", c.Uint("attention.key_length"))),
|
||||
swaRopeBase: c.Float("rope.swa.freq_base", 10000),
|
||||
swaRopeScale: swaRopeScale,
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("expert_weights_norm", true),
|
||||
routedScalingFactor: c.Float("expert_weights_scale", 1),
|
||||
decoderSparseStep: int(c.Uint("decoder_sparse_step", 1)),
|
||||
denseLayers: denseLayers,
|
||||
}
|
||||
}
|
||||
|
||||
func lagunaAttentionFactor(ropeType string, scaleFactor, attentionFactor float32) float32 {
|
||||
if attentionFactor != 0 {
|
||||
return attentionFactor
|
||||
}
|
||||
if ropeType == "yarn" && scaleFactor > 1 {
|
||||
return float32(0.1*math.Log(float64(scaleFactor)) + 1)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func slidingWindowPattern(c fs.Config, numLayers int) []bool {
|
||||
pattern := c.Bools("attention.sliding_window_pattern")
|
||||
if len(pattern) == numLayers {
|
||||
return pattern
|
||||
}
|
||||
|
||||
layerTypes := configUints(c, "attention.layer_types")
|
||||
if len(layerTypes) == numLayers {
|
||||
pattern = make([]bool, numLayers)
|
||||
for i, layerType := range layerTypes {
|
||||
pattern[i] = layerType == 1
|
||||
}
|
||||
return pattern
|
||||
}
|
||||
|
||||
return make([]bool, numLayers)
|
||||
}
|
||||
|
||||
func configUints(c fs.Config, key string) []uint32 {
|
||||
keyExists := c.Value(c.Architecture()+"."+key) != nil || c.Value(key) != nil
|
||||
if cc, ok := c.(interface {
|
||||
Uints(string, ...[]uint32) []uint32
|
||||
}); ok {
|
||||
if values := cc.Uints(key); len(values) > 0 && (keyExists || !(len(values) == 1 && values[0] == 0)) {
|
||||
return values
|
||||
}
|
||||
}
|
||||
|
||||
ints := c.Ints(key)
|
||||
if len(ints) > 0 && (keyExists || !(len(ints) == 1 && ints[0] == 0)) {
|
||||
values := make([]uint32, len(ints))
|
||||
for i, v := range ints {
|
||||
values[i] = uint32(v)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
if scalar := c.Uint(key); scalar != 0 {
|
||||
return []uint32{scalar}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func expandIntArray(values []uint32, n int, fallback uint32) []int {
|
||||
if len(values) == 0 {
|
||||
values = []uint32{fallback}
|
||||
}
|
||||
defaultValue := values[0]
|
||||
if len(values) == 1 {
|
||||
defaultValue = values[0]
|
||||
}
|
||||
|
||||
out := make([]int, n)
|
||||
for i := range out {
|
||||
if i < len(values) {
|
||||
out[i] = int(values[i])
|
||||
} else {
|
||||
out[i] = int(defaultValue)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.Options.applyRotaryPositionEmbeddings(ctx, layer, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if m.Cache != nil {
|
||||
m.Cache.SetLayer(i)
|
||||
if wrapper, ok := m.Cache.(*kvcache.WrapperCache); ok {
|
||||
cacheType := cacheTypeCausal
|
||||
if m.Options.layerIsSliding(i) {
|
||||
cacheType = cacheTypeSWA
|
||||
}
|
||||
wrapper.SetLayerType(cacheType)
|
||||
}
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, i, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("laguna", New)
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
237
model/models/laguna/model_test.go
Normal file
237
model/models/laguna/model_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package laguna
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testConfig map[string]any
|
||||
|
||||
func (c testConfig) Architecture() string { return "laguna" }
|
||||
|
||||
func (c testConfig) key(key string) string {
|
||||
switch {
|
||||
case len(key) >= len("tokenizer.") && key[:len("tokenizer.")] == "tokenizer.":
|
||||
return key
|
||||
case len(key) >= len("general.") && key[:len("general.")] == "general.":
|
||||
return key
|
||||
default:
|
||||
return "laguna." + key
|
||||
}
|
||||
}
|
||||
|
||||
func (c testConfig) String(key string, defaultValue ...string) string {
|
||||
if v, ok := c[c.key(key)].(string); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c testConfig) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
switch v := c[c.key(key)].(type) {
|
||||
case uint32:
|
||||
return v
|
||||
case int:
|
||||
return uint32(v)
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c testConfig) Float(key string, defaultValue ...float32) float32 {
|
||||
if v, ok := c[c.key(key)].(float32); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c testConfig) Bool(key string, defaultValue ...bool) bool {
|
||||
if v, ok := c[c.key(key)].(bool); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c testConfig) Strings(key string, defaultValue ...[]string) []string {
|
||||
if v, ok := c[c.key(key)].([]string); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c testConfig) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
if v, ok := c[c.key(key)].([]int32); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c testConfig) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
if v, ok := c[c.key(key)].([]uint32); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c testConfig) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
if v, ok := c[c.key(key)].([]float32); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c testConfig) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
if v, ok := c[c.key(key)].([]bool); ok {
|
||||
return v
|
||||
}
|
||||
if len(defaultValue) > 0 {
|
||||
return defaultValue[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c testConfig) Len() int { return len(c) }
|
||||
|
||||
func (c testConfig) Keys() iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for key := range c {
|
||||
if !yield(key) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c testConfig) Value(key string) any { return c[key] }
|
||||
|
||||
func TestNewOptionsLayerConfig(t *testing.T) {
|
||||
cfg := testConfig{
|
||||
"laguna.block_count": uint32(4),
|
||||
"laguna.embedding_length": uint32(128),
|
||||
"laguna.attention.key_length": uint32(16),
|
||||
"laguna.attention.head_count": []uint32{8, 16, 16, 16},
|
||||
"laguna.attention.head_count_kv": uint32(4),
|
||||
"laguna.attention.layer_norm_rms_epsilon": float32(1e-6),
|
||||
"laguna.attention.sliding_window": uint32(512),
|
||||
"laguna.attention.sliding_window_pattern": []bool{false, true, true, true},
|
||||
"laguna.rope.dimension_count": uint32(8),
|
||||
"laguna.rope.freq_base": float32(500000),
|
||||
"laguna.rope.scaling.factor": float32(32),
|
||||
"laguna.rope.scaling.original_context_length": uint32(4096),
|
||||
"laguna.rope.swa.dimension_count": uint32(16),
|
||||
"laguna.rope.swa.freq_base": float32(10000),
|
||||
"laguna.expert_count": uint32(32),
|
||||
"laguna.expert_used_count": uint32(4),
|
||||
"laguna.expert_weights_norm": true,
|
||||
"laguna.expert_weights_scale": float32(2.5),
|
||||
"laguna.decoder_sparse_step": uint32(1),
|
||||
"laguna.dense_layers": []uint32{0},
|
||||
}
|
||||
|
||||
opts := newOptions(cfg, 4)
|
||||
if got := opts.numHeadsForLayer(0); got != 8 {
|
||||
t.Fatalf("layer 0 heads = %d, want 8", got)
|
||||
}
|
||||
if got := opts.numHeadsForLayer(1); got != 16 {
|
||||
t.Fatalf("layer 1 heads = %d, want 16", got)
|
||||
}
|
||||
if opts.layerIsSliding(0) {
|
||||
t.Fatal("layer 0 should be full attention")
|
||||
}
|
||||
if !opts.layerIsSliding(1) {
|
||||
t.Fatal("layer 1 should be sliding attention")
|
||||
}
|
||||
if opts.layerUsesMoE(0) {
|
||||
t.Fatal("layer 0 should be dense")
|
||||
}
|
||||
if !opts.layerUsesMoE(1) {
|
||||
t.Fatal("layer 1 should use MoE")
|
||||
}
|
||||
if opts.fullRopeDim != 8 || opts.swaRopeDim != 16 {
|
||||
t.Fatalf("rope dims = full %d swa %d, want 8/16", opts.fullRopeDim, opts.swaRopeDim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOptionsYarnAttentionFactorFallback(t *testing.T) {
|
||||
cfg := testConfig{
|
||||
"laguna.block_count": uint32(1),
|
||||
"laguna.embedding_length": uint32(128),
|
||||
"laguna.attention.key_length": uint32(16),
|
||||
"laguna.attention.head_count": uint32(8),
|
||||
"laguna.attention.head_count_kv": uint32(4),
|
||||
"laguna.rope.scaling.type": "yarn",
|
||||
"laguna.rope.scaling.factor": float32(32),
|
||||
}
|
||||
|
||||
opts := newOptions(cfg, 1)
|
||||
want := float32(0.1*math.Log(32) + 1)
|
||||
if got := opts.fullRopeAttentionFactor; math.Abs(float64(got-want)) > 1e-6 {
|
||||
t.Fatalf("fullRopeAttentionFactor = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRejectsUnsupportedLagunaVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg testConfig
|
||||
}{
|
||||
{
|
||||
name: "attention sinks",
|
||||
cfg: testConfig{
|
||||
"laguna.attention.sink_enabled": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non per-head gate",
|
||||
cfg: testConfig{
|
||||
"laguna.attention.gating_type": uint32(0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing qk norm",
|
||||
cfg: testConfig{
|
||||
"laguna.attention.gating_type": uint32(1),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non sigmoid experts",
|
||||
cfg: testConfig{
|
||||
"laguna.attention.gating_type": uint32(1),
|
||||
"laguna.attention.qk_norm": true,
|
||||
"laguna.expert_gating_func": uint32(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if _, err := New(tt.cfg); err == nil {
|
||||
t.Fatal("expected unsupported variant error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
_ "github.com/ollama/ollama/model/models/laguna"
|
||||
_ "github.com/ollama/ollama/model/models/lfm2"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
|
||||
355
model/models/nemotronh/imageproc.go
Normal file
355
model/models/nemotronh/imageproc.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
patchSize int
|
||||
numChannels int
|
||||
maxTiles int
|
||||
minNumPatches int
|
||||
maxNumPatches int
|
||||
useThumbnail bool
|
||||
projectorScale int
|
||||
imageMean [3]float32
|
||||
imageStd [3]float32
|
||||
}
|
||||
|
||||
type processedVisionTile struct {
|
||||
data []float32
|
||||
size image.Point
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
mean := c.Floats("vision.image_mean")
|
||||
std := c.Floats("vision.image_std")
|
||||
|
||||
processor := ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 512)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
maxTiles: int(c.Uint("vision.max_tiles", 12)),
|
||||
minNumPatches: int(c.Uint("vision.min_num_patches")),
|
||||
maxNumPatches: int(c.Uint("vision.max_num_patches")),
|
||||
useThumbnail: c.Bool("vision.use_thumbnail", true),
|
||||
projectorScale: int(c.Uint("vision.projector.scale_factor", 2)),
|
||||
imageMean: imageproc.ClipDefaultMean,
|
||||
imageStd: imageproc.ClipDefaultSTD,
|
||||
}
|
||||
|
||||
if len(mean) >= 3 {
|
||||
processor.imageMean = [3]float32{mean[0], mean[1], mean[2]}
|
||||
}
|
||||
if len(std) >= 3 {
|
||||
processor.imageStd = [3]float32{std[0], std[1], std[2]}
|
||||
}
|
||||
if processor.imageSize <= 0 {
|
||||
processor.imageSize = 512
|
||||
}
|
||||
if processor.patchSize <= 0 {
|
||||
processor.patchSize = 16
|
||||
}
|
||||
if processor.numChannels <= 0 {
|
||||
processor.numChannels = 3
|
||||
}
|
||||
if processor.maxTiles <= 0 {
|
||||
processor.maxTiles = 12
|
||||
}
|
||||
if processor.projectorScale <= 0 {
|
||||
processor.projectorScale = 2
|
||||
}
|
||||
|
||||
return processor
|
||||
}
|
||||
|
||||
func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionTile, error) {
|
||||
img = imageproc.Composite(img)
|
||||
if p.useDynamicResolution() {
|
||||
return p.processDynamicImage(img)
|
||||
}
|
||||
|
||||
return p.processTiledImage(img), nil
|
||||
}
|
||||
|
||||
func (p ImageProcessor) useDynamicResolution() bool {
|
||||
return p.minNumPatches > 0 || p.maxNumPatches > 0
|
||||
}
|
||||
|
||||
func (p ImageProcessor) processTiledImage(img image.Image) []processedVisionTile {
|
||||
bounds := img.Bounds()
|
||||
origWidth := bounds.Dx()
|
||||
origHeight := bounds.Dy()
|
||||
targetRatios := nemotronTargetRatios(p.maxTiles)
|
||||
gridWidth, gridHeight := findClosestAspectRatio(float64(origWidth)/float64(origHeight), targetRatios, origWidth, origHeight, p.imageSize)
|
||||
|
||||
targetWidth := p.imageSize * gridWidth
|
||||
targetHeight := p.imageSize * gridHeight
|
||||
resized := resizeImageBicubicCHW(img, targetWidth, targetHeight)
|
||||
|
||||
tiles := make([]processedVisionTile, 0, gridWidth*gridHeight+1)
|
||||
for row := range gridHeight {
|
||||
for col := range gridWidth {
|
||||
tile := cropCHWRegion(
|
||||
resized,
|
||||
targetWidth,
|
||||
targetHeight,
|
||||
p.numChannels,
|
||||
col*p.imageSize,
|
||||
row*p.imageSize,
|
||||
p.imageSize,
|
||||
p.imageSize,
|
||||
)
|
||||
tiles = append(tiles, processedVisionTile{
|
||||
data: normalizeVisionCHW(tile, p.imageMean, p.imageStd),
|
||||
size: image.Point{X: p.imageSize, Y: p.imageSize},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if p.useThumbnail && len(tiles) > 1 {
|
||||
thumbnail := resizeImageBicubicCHW(img, p.imageSize, p.imageSize)
|
||||
tiles = append(tiles, processedVisionTile{
|
||||
data: normalizeVisionCHW(thumbnail, p.imageMean, p.imageStd),
|
||||
size: image.Point{X: p.imageSize, Y: p.imageSize},
|
||||
})
|
||||
}
|
||||
|
||||
return tiles
|
||||
}
|
||||
|
||||
func (p ImageProcessor) processDynamicImage(img image.Image) ([]processedVisionTile, error) {
|
||||
bounds := img.Bounds()
|
||||
origWidth := bounds.Dx()
|
||||
origHeight := bounds.Dy()
|
||||
patchesWidth, patchesHeight := p.dynamicPatchGrid(origWidth, origHeight)
|
||||
if patchesWidth <= 0 || patchesHeight <= 0 {
|
||||
return nil, errors.New("nemotron_h_omni: invalid dynamic image patch grid")
|
||||
}
|
||||
|
||||
targetWidth := patchesWidth * p.patchSize
|
||||
targetHeight := patchesHeight * p.patchSize
|
||||
resized := resizeImageBicubicCHW(img, targetWidth, targetHeight)
|
||||
|
||||
return []processedVisionTile{{
|
||||
data: normalizeVisionCHW(resized, p.imageMean, p.imageStd),
|
||||
size: image.Point{X: targetWidth, Y: targetHeight},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (p ImageProcessor) dynamicPatchGrid(origWidth, origHeight int) (int, int) {
|
||||
patchesHeight := max(1, int(math.Round(float64(origHeight)/float64(p.patchSize)+0.5)))
|
||||
patchesWidth := max(1, int(math.Round(float64(origWidth)/float64(p.patchSize)+0.5)))
|
||||
|
||||
patches := patchesHeight * patchesWidth
|
||||
currentNumPatchesAvailable := p.maxNumPatches
|
||||
if currentNumPatchesAvailable <= 0 {
|
||||
currentNumPatchesAvailable = max(patches, p.minNumPatches)
|
||||
}
|
||||
|
||||
factor := math.Min(math.Sqrt(float64(currentNumPatchesAvailable)/float64(patches)), 1.0)
|
||||
targetPatchesHeight := max(1, int(math.Floor(factor*float64(patchesHeight))))
|
||||
targetPatchesWidth := max(1, int(math.Floor(factor*float64(patchesWidth))))
|
||||
|
||||
if currentNumPatchesAvailable > p.minNumPatches && targetPatchesHeight*targetPatchesWidth < p.minNumPatches {
|
||||
upFactor := math.Sqrt(float64(p.minNumPatches) / float64(targetPatchesHeight*targetPatchesWidth))
|
||||
targetPatchesHeight = int(math.Ceil(upFactor * float64(targetPatchesHeight)))
|
||||
targetPatchesWidth = int(math.Ceil(upFactor * float64(targetPatchesWidth)))
|
||||
}
|
||||
|
||||
targetPatchesHeight = roundPatchGridForPixelShuffle(targetPatchesHeight, targetPatchesWidth, currentNumPatchesAvailable, p.projectorScale)
|
||||
targetPatchesWidth = roundPatchGridForPixelShuffle(targetPatchesWidth, targetPatchesHeight, currentNumPatchesAvailable, p.projectorScale)
|
||||
|
||||
return targetPatchesWidth, targetPatchesHeight
|
||||
}
|
||||
|
||||
func roundPatchGridForPixelShuffle(v, other, maxPatches, divisor int) int {
|
||||
if divisor <= 1 {
|
||||
return v
|
||||
}
|
||||
rem := v % divisor
|
||||
if rem == 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
inc := divisor - rem
|
||||
if (v+inc)*other <= maxPatches {
|
||||
return v + inc
|
||||
}
|
||||
return max(divisor, v-rem)
|
||||
}
|
||||
|
||||
type nemotronImageRatio struct {
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func nemotronTargetRatios(maxTiles int) []nemotronImageRatio {
|
||||
targetRatios := make([]nemotronImageRatio, 0, maxTiles*maxTiles)
|
||||
for n := 1; n <= maxTiles; n++ {
|
||||
for w := 1; w <= n; w++ {
|
||||
for h := 1; h <= n; h++ {
|
||||
if w*h > maxTiles {
|
||||
continue
|
||||
}
|
||||
targetRatios = append(targetRatios, nemotronImageRatio{width: w, height: h})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unique := targetRatios[:0]
|
||||
for _, ratio := range targetRatios {
|
||||
if slices.Contains(unique, ratio) {
|
||||
continue
|
||||
}
|
||||
unique = append(unique, ratio)
|
||||
}
|
||||
|
||||
slices.SortFunc(unique, func(a, b nemotronImageRatio) int {
|
||||
return a.width*a.height - b.width*b.height
|
||||
})
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
func findClosestAspectRatio(aspectRatio float64, targetRatios []nemotronImageRatio, width, height, imageSize int) (int, int) {
|
||||
bestRatio := nemotronImageRatio{width: 1, height: 1}
|
||||
bestRatioDiff := math.MaxFloat64
|
||||
area := width * height
|
||||
|
||||
for _, ratio := range targetRatios {
|
||||
targetAspectRatio := float64(ratio.width) / float64(ratio.height)
|
||||
ratioDiff := math.Abs(aspectRatio - targetAspectRatio)
|
||||
if ratioDiff < bestRatioDiff {
|
||||
bestRatioDiff = ratioDiff
|
||||
bestRatio = ratio
|
||||
continue
|
||||
}
|
||||
|
||||
if ratioDiff == bestRatioDiff && area > int(0.5*float64(imageSize*imageSize*ratio.width*ratio.height)) {
|
||||
bestRatio = ratio
|
||||
}
|
||||
}
|
||||
|
||||
return bestRatio.width, bestRatio.height
|
||||
}
|
||||
|
||||
func resizeImageBicubicCHW(img image.Image, outW, outH int) []float32 {
|
||||
bounds := img.Bounds()
|
||||
inW := bounds.Dx()
|
||||
inH := bounds.Dy()
|
||||
src := make([]float32, 3*inW*inH)
|
||||
|
||||
for y := range inH {
|
||||
for x := range inW {
|
||||
r, g, b, _ := img.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA()
|
||||
src[y*inW+x] = float32(r>>8) / 255.0
|
||||
src[inW*inH+y*inW+x] = float32(g>>8) / 255.0
|
||||
src[2*inW*inH+y*inW+x] = float32(b>>8) / 255.0
|
||||
}
|
||||
}
|
||||
|
||||
dst := make([]float32, 3*outW*outH)
|
||||
scaleX := float64(inW) / float64(outW)
|
||||
scaleY := float64(inH) / float64(outH)
|
||||
|
||||
for oy := range outH {
|
||||
srcY := scaleY*(float64(oy)+0.5) - 0.5
|
||||
yBase := int(math.Floor(srcY))
|
||||
yFrac := clampUnit(srcY - float64(yBase))
|
||||
wy := torchBicubicWeights(yFrac)
|
||||
|
||||
for ox := range outW {
|
||||
srcX := scaleX*(float64(ox)+0.5) - 0.5
|
||||
xBase := int(math.Floor(srcX))
|
||||
xFrac := clampUnit(srcX - float64(xBase))
|
||||
wx := torchBicubicWeights(xFrac)
|
||||
|
||||
for c := range 3 {
|
||||
var sum float64
|
||||
channelOffset := c * inW * inH
|
||||
for ky := range 4 {
|
||||
iy := clampIndex(yBase-1+ky, 0, inH-1)
|
||||
for kx := range 4 {
|
||||
ix := clampIndex(xBase-1+kx, 0, inW-1)
|
||||
sum += float64(src[channelOffset+iy*inW+ix]) * wy[ky] * wx[kx]
|
||||
}
|
||||
}
|
||||
dst[c*outW*outH+oy*outW+ox] = float32(sum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func cropCHWRegion(values []float32, width, height, channels, left, top, cropW, cropH int) []float32 {
|
||||
out := make([]float32, channels*cropW*cropH)
|
||||
channelSize := width * height
|
||||
cropSize := cropW * cropH
|
||||
for c := range channels {
|
||||
srcBase := c * channelSize
|
||||
dstBase := c * cropSize
|
||||
for y := range cropH {
|
||||
copy(out[dstBase+y*cropW:dstBase+(y+1)*cropW], values[srcBase+(top+y)*width+left:srcBase+(top+y)*width+left+cropW])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeVisionCHW(values []float32, mean, std [3]float32) []float32 {
|
||||
out := make([]float32, len(values))
|
||||
channelSize := len(values) / 3
|
||||
for c := range 3 {
|
||||
base := c * channelSize
|
||||
for i := range channelSize {
|
||||
out[base+i] = (values[base+i] - mean[c]) / std[c]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func torchBicubicWeights(t float64) [4]float64 {
|
||||
const a = -0.75
|
||||
return [4]float64{
|
||||
bicubicConvolution2(t+1.0, a),
|
||||
bicubicConvolution1(t, a),
|
||||
bicubicConvolution1(1.0-t, a),
|
||||
bicubicConvolution2(2.0-t, a),
|
||||
}
|
||||
}
|
||||
|
||||
func bicubicConvolution1(x, a float64) float64 {
|
||||
return ((a+2)*x-(a+3))*x*x + 1
|
||||
}
|
||||
|
||||
func bicubicConvolution2(x, a float64) float64 {
|
||||
return ((a*x-5*a)*x+8*a)*x - 4*a
|
||||
}
|
||||
|
||||
func clampUnit(v float64) float64 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v > 1 {
|
||||
return 1
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func clampIndex(v, lo, hi int) int {
|
||||
if v < lo {
|
||||
return lo
|
||||
}
|
||||
if v > hi {
|
||||
return hi
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -117,9 +117,7 @@ func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
func (m *Model) forwardHiddenStates(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) {
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
@@ -137,11 +135,24 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
|
||||
}
|
||||
|
||||
func (m *Model) forwardLogits(ctx ml.Context, batch input.Batch, hiddenStates ml.Tensor) (ml.Tensor, error) {
|
||||
hiddenStates, err := m.forwardHiddenStates(ctx, batch, hiddenStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
return m.forwardLogits(ctx, batch, hiddenStates)
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) (*Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
@@ -306,6 +317,10 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
return newTextModel(c)
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("nemotron_h", New)
|
||||
model.Register("nemotron_h_moe", New)
|
||||
|
||||
511
model/models/nemotronh/model_audio.go
Normal file
511
model/models/nemotronh/model_audio.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
type AudioOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
intermediateSize int
|
||||
convKernelSize int
|
||||
melBins int
|
||||
sampleRate int
|
||||
subsamplingKernel int
|
||||
subsamplingStride int
|
||||
scaleInput bool
|
||||
eps float32
|
||||
}
|
||||
|
||||
type AudioFeatureExtractor struct {
|
||||
FB ml.Tensor `gguf:"fb"`
|
||||
Window ml.Tensor `gguf:"window"`
|
||||
|
||||
mu sync.Mutex
|
||||
fb []float32
|
||||
window []float32
|
||||
fbShape [2]int
|
||||
}
|
||||
|
||||
func (f *AudioFeatureExtractor) windowAndFilters(melBins, freqBins, sampleRate int) ([]float32, []float32) {
|
||||
if f == nil {
|
||||
return defaultParakeetWindow(), buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.window == nil {
|
||||
if f.Window != nil {
|
||||
if values := f.Window.BackendGet(); len(values) == parakeetWinLength {
|
||||
f.window = values
|
||||
}
|
||||
}
|
||||
if f.window == nil {
|
||||
f.window = defaultParakeetWindow()
|
||||
}
|
||||
}
|
||||
|
||||
if f.fb == nil {
|
||||
if f.FB != nil {
|
||||
if values := f.FB.BackendGet(); len(values) == melBins*freqBins {
|
||||
f.fb = values
|
||||
f.fbShape = [2]int{melBins, freqBins}
|
||||
}
|
||||
}
|
||||
if f.fb == nil {
|
||||
f.fb = buildSlaneyMelFilterBank(freqBins, melBins, sampleRate)
|
||||
f.fbShape = [2]int{melBins, freqBins}
|
||||
}
|
||||
}
|
||||
|
||||
return f.window, f.fb
|
||||
}
|
||||
|
||||
type AudioSubsampling struct {
|
||||
Conv0 *nn.Conv2D `gguf:"conv0"`
|
||||
DW1 *AudioDepthwiseConv2D `gguf:"dw1"`
|
||||
PW1 *nn.Conv2D `gguf:"pw1"`
|
||||
DW2 *AudioDepthwiseConv2D `gguf:"dw2"`
|
||||
PW2 *nn.Conv2D `gguf:"pw2"`
|
||||
|
||||
Linear *nn.Linear `gguf:"linear"`
|
||||
}
|
||||
|
||||
type AudioDepthwiseConv2D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
type AudioFeedForward struct {
|
||||
Up *nn.Linear `gguf:"up"`
|
||||
Down *nn.Linear `gguf:"down"`
|
||||
}
|
||||
|
||||
type AudioSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
RelativeKey *nn.Linear `gguf:"attn_rel_k"`
|
||||
BiasU ml.Tensor `gguf:"attn_bias_u"`
|
||||
BiasV ml.Tensor `gguf:"attn_bias_v"`
|
||||
}
|
||||
|
||||
type AudioConvolutionModule struct {
|
||||
Pointwise1 *nn.Linear `gguf:"conv_pw1"`
|
||||
Depthwise ml.Tensor `gguf:"conv_dw.weight"`
|
||||
BatchNorm *AudioBatchNorm1D `gguf:"conv_bn"`
|
||||
Pointwise2 *nn.Linear `gguf:"conv_pw2"`
|
||||
}
|
||||
|
||||
type AudioBatchNorm1D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
RunningMean ml.Tensor `gguf:"running_mean"`
|
||||
RunningVar ml.Tensor `gguf:"running_var"`
|
||||
}
|
||||
|
||||
type AudioLayer struct {
|
||||
FFN1Norm *nn.LayerNorm `gguf:"ffn1_norm"`
|
||||
FFN1Up *nn.Linear `gguf:"ffn1_up"`
|
||||
FFN1Down *nn.Linear `gguf:"ffn1_down"`
|
||||
|
||||
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
|
||||
Attention *AudioSelfAttention
|
||||
|
||||
ConvNorm *nn.LayerNorm `gguf:"conv_norm"`
|
||||
Conv *AudioConvolutionModule
|
||||
|
||||
FFN2Norm *nn.LayerNorm `gguf:"ffn2_norm"`
|
||||
FFN2Up *nn.Linear `gguf:"ffn2_up"`
|
||||
FFN2Down *nn.Linear `gguf:"ffn2_down"`
|
||||
|
||||
OutputNorm *nn.LayerNorm `gguf:"out_norm"`
|
||||
}
|
||||
|
||||
type AudioModel struct {
|
||||
FeatureExtractor *AudioFeatureExtractor `gguf:"feature_extractor"`
|
||||
Subsampling *AudioSubsampling `gguf:"subsampling"`
|
||||
Layers []AudioLayer `gguf:"blk"`
|
||||
|
||||
*AudioOptions
|
||||
}
|
||||
|
||||
type AudioProjector struct {
|
||||
Norm *nn.RMSNorm `gguf:"norm"`
|
||||
Linear1 *nn.Linear `gguf:"1"`
|
||||
Linear2 *nn.Linear `gguf:"2"`
|
||||
}
|
||||
|
||||
func (p *AudioProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
|
||||
x = p.Norm.Forward(ctx, x, eps)
|
||||
x = audioF32(ctx, p.Linear1.Forward(ctx, x))
|
||||
x = x.RELU(ctx)
|
||||
x = x.Mul(ctx, x)
|
||||
return audioF32(ctx, p.Linear2.Forward(ctx, x))
|
||||
}
|
||||
|
||||
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, validFrames int, projector *AudioProjector) ml.Tensor {
|
||||
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
|
||||
validLen := validFrames
|
||||
|
||||
x = forwardAudioConv2D(ctx, m.Subsampling.Conv0, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
|
||||
x = x.RELU(ctx)
|
||||
validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
|
||||
x = applyAudioTimeMask(ctx, x, validLen)
|
||||
|
||||
x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW1, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
|
||||
x = forwardAudioConv2D(ctx, m.Subsampling.PW1, x, 1, 1, 0, 0, 1, 1)
|
||||
x = x.RELU(ctx)
|
||||
validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
|
||||
x = applyAudioTimeMask(ctx, x, validLen)
|
||||
|
||||
x = forwardAudioDepthwiseConv2D(ctx, m.Subsampling.DW2, x, m.subsamplingStride, m.subsamplingStride, audioConvPadding(m.subsamplingKernel), audioConvPadding(m.subsamplingKernel), 1, 1)
|
||||
x = forwardAudioConv2D(ctx, m.Subsampling.PW2, x, 1, 1, 0, 0, 1, 1)
|
||||
x = x.RELU(ctx)
|
||||
validLen = convOutputLength(validLen, m.subsamplingKernel, m.subsamplingStride, audioConvPadding(m.subsamplingKernel))
|
||||
x = applyAudioTimeMask(ctx, x, validLen)
|
||||
|
||||
x = flattenAudioSubsamplingOutput(ctx, x)
|
||||
x = m.Subsampling.Linear.Forward(ctx, x)
|
||||
|
||||
if m.scaleInput {
|
||||
x = x.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
}
|
||||
if validLen > 0 && validLen < x.Dim(1) {
|
||||
x = x.Slice(ctx, 1, 0, validLen, 1).Contiguous(ctx)
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
x = m.Layers[i].Forward(ctx, x, validLen, m.AudioOptions)
|
||||
}
|
||||
|
||||
if projector != nil {
|
||||
x = projector.Forward(ctx, x, m.eps)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func flattenAudioSubsamplingOutput(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
fOut, tOut, cOut := x.Dim(0), x.Dim(1), x.Dim(2)
|
||||
// PyTorch flattens the subsampling output after [B, C, T, F] ->
|
||||
// [B, T, C, F], so F must remain the fastest dimension inside each
|
||||
// channel block before the linear projection.
|
||||
x = x.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
return x.Reshape(ctx, cOut*fOut, tOut)
|
||||
}
|
||||
|
||||
func (l *AudioLayer) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
|
||||
residual := x
|
||||
x = audioFeedForward(ctx, l.FFN1Up, l.FFN1Down, l.FFN1Norm.Forward(ctx, x, opts.eps)).Scale(ctx, 0.5)
|
||||
x = residual.Add(ctx, x)
|
||||
|
||||
residual = x
|
||||
x = l.Attention.Forward(ctx, l.AttentionNorm.Forward(ctx, x, opts.eps), validLen, opts)
|
||||
x = residual.Add(ctx, x)
|
||||
|
||||
residual = x
|
||||
x = l.Conv.Forward(ctx, l.ConvNorm.Forward(ctx, x, opts.eps), opts)
|
||||
x = residual.Add(ctx, x)
|
||||
|
||||
residual = x
|
||||
x = audioFeedForward(ctx, l.FFN2Up, l.FFN2Down, l.FFN2Norm.Forward(ctx, x, opts.eps)).Scale(ctx, 0.5)
|
||||
x = residual.Add(ctx, x)
|
||||
|
||||
return l.OutputNorm.Forward(ctx, x, opts.eps)
|
||||
}
|
||||
|
||||
func audioFeedForward(ctx ml.Context, up, down *nn.Linear, x ml.Tensor) ml.Tensor {
|
||||
x = audioF32(ctx, up.Forward(ctx, x))
|
||||
x = x.SILU(ctx)
|
||||
return audioF32(ctx, down.Forward(ctx, x))
|
||||
}
|
||||
|
||||
func (a *AudioSelfAttention) Forward(ctx ml.Context, x ml.Tensor, validLen int, opts *AudioOptions) ml.Tensor {
|
||||
seqLen := x.Dim(1)
|
||||
headDim := opts.headDim
|
||||
numHeads := opts.numHeads
|
||||
|
||||
q := audioF32(ctx, a.Query.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
|
||||
k := audioF32(ctx, a.Key.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
|
||||
v := audioF32(ctx, a.Value.Forward(ctx, x)).Reshape(ctx, headDim, numHeads, seqLen)
|
||||
|
||||
qU := q
|
||||
if a.BiasU != nil {
|
||||
qU = qU.Add(ctx, audioF32(ctx, a.BiasU).Reshape(ctx, headDim, numHeads, 1))
|
||||
}
|
||||
qV := q
|
||||
if a.BiasV != nil {
|
||||
qV = qV.Add(ctx, audioF32(ctx, a.BiasV).Reshape(ctx, headDim, numHeads, 1))
|
||||
}
|
||||
|
||||
qP := qU.Permute(ctx, 0, 2, 1, 3)
|
||||
kP := k.Permute(ctx, 0, 2, 1, 3)
|
||||
logits := kP.MulmatFullPrec(ctx, qP)
|
||||
|
||||
positionEmbeddings := parakeetPositionEmbeddings(ctx, seqLen, opts.hiddenSize)
|
||||
relKey := audioF32(ctx, a.RelativeKey.Forward(ctx, positionEmbeddings)).Reshape(ctx, headDim, numHeads, 2*seqLen-1)
|
||||
pP := relKey.Permute(ctx, 0, 2, 1, 3)
|
||||
qVP := qV.Permute(ctx, 0, 2, 1, 3)
|
||||
relLogits := pP.MulmatFullPrec(ctx, qVP)
|
||||
relLogits = relativeShiftParakeet(ctx, relLogits, seqLen, numHeads)
|
||||
|
||||
logits = logits.Add(ctx, relLogits)
|
||||
logits = logits.Scale(ctx, math.Pow(float64(headDim), -0.5))
|
||||
if validLen > 0 && validLen < seqLen {
|
||||
logits = logits.Add(ctx, audioAttentionMask(ctx, seqLen, validLen))
|
||||
}
|
||||
logits = logits.Softmax(ctx)
|
||||
|
||||
vP := v.Permute(ctx, 0, 2, 1, 3)
|
||||
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
out := vPT.Mulmat(ctx, logits)
|
||||
out = out.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
out = out.Reshape(ctx, opts.hiddenSize, seqLen)
|
||||
return audioF32(ctx, a.Output.Forward(ctx, out))
|
||||
}
|
||||
|
||||
func (c *AudioConvolutionModule) Forward(ctx ml.Context, x ml.Tensor, opts *AudioOptions) ml.Tensor {
|
||||
x = audioF32(ctx, c.Pointwise1.Forward(ctx, x))
|
||||
hidden := x.Dim(0) / 2
|
||||
value := x.Slice(ctx, 0, 0, hidden, 1).Contiguous(ctx)
|
||||
gate := x.Slice(ctx, 0, hidden, 2*hidden, 1).Contiguous(ctx).Sigmoid(ctx)
|
||||
x = value.Mul(ctx, gate)
|
||||
|
||||
x = audioDepthwiseConv1DSame(ctx, x, c.Depthwise, audioConvPadding(opts.convKernelSize))
|
||||
x = c.BatchNorm.Forward(ctx, x, opts.eps)
|
||||
x = x.SILU(ctx)
|
||||
return audioF32(ctx, c.Pointwise2.Forward(ctx, x))
|
||||
}
|
||||
|
||||
func audioF32(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
if x.DType() == ml.DTypeF32 {
|
||||
return x
|
||||
}
|
||||
|
||||
// Metal binary kernels used by the audio graph require F32 operands here.
|
||||
// This likely slows audio and should be revisited once the precision vs.
|
||||
// speed tradeoff is validated against BF16-native elementwise paths.
|
||||
return x.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
|
||||
func (b *AudioBatchNorm1D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
|
||||
if b == nil || b.RunningMean == nil || b.RunningVar == nil {
|
||||
return x
|
||||
}
|
||||
|
||||
hidden := x.Dim(0)
|
||||
epsValues := make([]float32, hidden)
|
||||
for i := range epsValues {
|
||||
epsValues[i] = eps
|
||||
}
|
||||
|
||||
variance := b.RunningVar.Add(ctx, ctx.Input().FromFloats(epsValues, hidden))
|
||||
x = x.Sub(ctx, b.RunningMean)
|
||||
x = x.Div(ctx, variance.Sqrt(ctx))
|
||||
if b.Weight != nil {
|
||||
x = x.Mul(ctx, b.Weight)
|
||||
}
|
||||
if b.Bias != nil {
|
||||
x = x.Add(ctx, b.Bias)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func forwardAudioConv2D(ctx ml.Context, conv *nn.Conv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
weight := conv.Weight.Contiguous(ctx)
|
||||
x = weight.Conv2D(ctx, x, s0, s1, p0, p1, d0, d1)
|
||||
if conv.Bias != nil {
|
||||
x = x.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1))
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func forwardAudioDepthwiseConv2D(ctx ml.Context, conv *AudioDepthwiseConv2D, x ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
x = audioDepthwiseConv2D(ctx, x, conv.Weight, s0, s1, p0, p1, d0, d1)
|
||||
if conv.Bias != nil {
|
||||
x = x.Add(ctx, conv.Bias.Reshape(ctx, 1, 1, -1))
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func applyAudioTimeMask(ctx ml.Context, x ml.Tensor, validLen int) ml.Tensor {
|
||||
if validLen <= 0 || validLen >= x.Dim(1) {
|
||||
return x
|
||||
}
|
||||
mask := make([]float32, x.Dim(1))
|
||||
for i := range validLen {
|
||||
mask[i] = 1
|
||||
}
|
||||
return x.Mul(ctx, ctx.Input().FromFloats(mask, 1, x.Dim(1), 1, 1))
|
||||
}
|
||||
|
||||
func audioDepthwiseConv1DSame(ctx ml.Context, x, kernel ml.Tensor, padding int) ml.Tensor {
|
||||
kernelSize := kernel.Dim(0)
|
||||
seqLen := x.Dim(1)
|
||||
|
||||
kernelT := kernel.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
var out ml.Tensor
|
||||
for k := range kernelSize {
|
||||
offset := k - padding
|
||||
shifted := x
|
||||
switch {
|
||||
case offset > 0:
|
||||
shifted = x.Slice(ctx, 1, offset, seqLen, 1).Contiguous(ctx)
|
||||
shifted = shifted.PadExt(ctx, 0, 0, 0, offset, 0, 0, 0, 0)
|
||||
case offset < 0:
|
||||
shift := -offset
|
||||
shifted = x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
|
||||
shifted = shifted.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
|
||||
}
|
||||
|
||||
wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx)
|
||||
term := shifted.Mul(ctx, wk)
|
||||
if out == nil {
|
||||
out = term
|
||||
} else {
|
||||
out = out.Add(ctx, term)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func audioDepthwiseConv2D(ctx ml.Context, x, kernel ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
if d0 != 1 || d1 != 1 {
|
||||
panic("audio depthwise conv2d only supports dilation 1")
|
||||
}
|
||||
|
||||
kernel = kernel.Contiguous(ctx)
|
||||
kernelW, kernelH := kernel.Dim(0), kernel.Dim(1)
|
||||
outW := convOutputLength(x.Dim(0), kernelW, s0, p0)
|
||||
outH := convOutputLength(x.Dim(1), kernelH, s1, p1)
|
||||
padded := x.PadExt(ctx, p0, p0, p1, p1, 0, 0, 0, 0)
|
||||
|
||||
var out ml.Tensor
|
||||
for ky := range kernelH {
|
||||
for kx := range kernelW {
|
||||
patch := padded.Slice(ctx, 0, kx, kx+s0*(outW-1)+1, s0).Contiguous(ctx)
|
||||
patch = patch.Slice(ctx, 1, ky, ky+s1*(outH-1)+1, s1).Contiguous(ctx)
|
||||
|
||||
wk := kernel.Slice(ctx, 0, kx, kx+1, 1).Slice(ctx, 1, ky, ky+1, 1).Contiguous(ctx)
|
||||
if wk.Dim(2) == 1 {
|
||||
wk = wk.Permute(ctx, 0, 1, 3, 2).Contiguous(ctx)
|
||||
} else {
|
||||
wk = wk.Reshape(ctx, 1, 1, wk.Dim(2), wk.Dim(3))
|
||||
}
|
||||
|
||||
term := patch.Mul(ctx, wk)
|
||||
if out == nil {
|
||||
out = term
|
||||
} else {
|
||||
out = out.Add(ctx, term)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func convOutputLength(inputLength, kernel, stride, padding int) int {
|
||||
if inputLength <= 0 {
|
||||
return 0
|
||||
}
|
||||
return (inputLength+2*padding-kernel)/stride + 1
|
||||
}
|
||||
|
||||
func audioConvPadding(kernel int) int {
|
||||
return (kernel - 1) / 2
|
||||
}
|
||||
|
||||
func parakeetPositionEmbeddings(ctx ml.Context, seqLen, hiddenSize int) ml.Tensor {
|
||||
half := hiddenSize / 2
|
||||
values := make([]float32, hiddenSize*(2*seqLen-1))
|
||||
for posIdx, pos := 0, seqLen-1; posIdx < 2*seqLen-1; posIdx, pos = posIdx+1, pos-1 {
|
||||
for i := range half {
|
||||
invFreq := math.Pow(10000, -float64(2*i)/float64(hiddenSize))
|
||||
angle := float64(pos) * invFreq
|
||||
values[posIdx*hiddenSize+2*i] = float32(math.Sin(angle))
|
||||
values[posIdx*hiddenSize+2*i+1] = float32(math.Cos(angle))
|
||||
}
|
||||
}
|
||||
return ctx.Input().FromFloats(values, hiddenSize, 2*seqLen-1)
|
||||
}
|
||||
|
||||
func relativeShiftParakeet(ctx ml.Context, x ml.Tensor, seqLen, numHeads int) ml.Tensor {
|
||||
positionLen := 2*seqLen - 1
|
||||
x = x.PadExt(ctx, 1, 0, 0, 0, 0, 0, 0, 0)
|
||||
x = x.Reshape(ctx, seqLen, positionLen+1, numHeads)
|
||||
x = x.Slice(ctx, 1, 1, positionLen+1, 1).Contiguous(ctx)
|
||||
x = x.Reshape(ctx, positionLen, seqLen, numHeads)
|
||||
return x.Slice(ctx, 0, 0, seqLen, 1).Contiguous(ctx)
|
||||
}
|
||||
|
||||
func audioAttentionMask(ctx ml.Context, seqLen, validLen int) ml.Tensor {
|
||||
values := make([]float32, seqLen*seqLen)
|
||||
for q := range seqLen {
|
||||
for k := range seqLen {
|
||||
if q >= validLen || k >= validLen {
|
||||
values[q*seqLen+k] = -1e9
|
||||
}
|
||||
}
|
||||
}
|
||||
return ctx.Input().FromFloats(values, seqLen, seqLen, 1)
|
||||
}
|
||||
|
||||
func newAudioModel(c fs.Config) *AudioModel {
|
||||
numLayers := int(c.Uint("audio.block_count", 0))
|
||||
if numLayers == 0 {
|
||||
return nil
|
||||
}
|
||||
return &AudioModel{
|
||||
Layers: make([]AudioLayer, numLayers),
|
||||
AudioOptions: newAudioOptions(c),
|
||||
}
|
||||
}
|
||||
|
||||
func newAudioProjector(c fs.Config) *AudioProjector {
|
||||
if c.Uint("audio.block_count", 0) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &AudioProjector{}
|
||||
}
|
||||
|
||||
func newAudioOptions(c fs.Config) *AudioOptions {
|
||||
hiddenSize := int(c.Uint("audio.embedding_length", 1024))
|
||||
numHeads := int(c.Uint("audio.attention.head_count", 8))
|
||||
headDim := hiddenSize / max(1, numHeads)
|
||||
return &AudioOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
intermediateSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
|
||||
convKernelSize: int(c.Uint("audio.conv_kernel_size", 9)),
|
||||
melBins: int(c.Uint("audio.num_mel_bins", 128)),
|
||||
sampleRate: int(c.Uint("audio.sample_rate", 16000)),
|
||||
subsamplingKernel: int(c.Uint("audio.subsampling_conv_kernel_size", 3)),
|
||||
subsamplingStride: int(c.Uint("audio.subsampling_conv_stride", 2)),
|
||||
scaleInput: c.Bool("audio.scale_input", false),
|
||||
eps: c.Float("audio.attention.layer_norm_epsilon", 1e-5),
|
||||
}
|
||||
}
|
||||
|
||||
func defaultAudioOptions() *AudioOptions {
|
||||
return &AudioOptions{
|
||||
hiddenSize: 1024,
|
||||
numHeads: 8,
|
||||
headDim: 128,
|
||||
intermediateSize: 4096,
|
||||
convKernelSize: 9,
|
||||
melBins: 128,
|
||||
sampleRate: 16000,
|
||||
subsamplingKernel: 3,
|
||||
subsamplingStride: 2,
|
||||
eps: 1e-5,
|
||||
}
|
||||
}
|
||||
239
model/models/nemotronh/model_omni.go
Normal file
239
model/models/nemotronh/model_omni.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type OmniModel struct {
|
||||
*Model
|
||||
*VisionModel `gguf:"v"`
|
||||
*AudioModel `gguf:"a"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*AudioProjector `gguf:"mm.a"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageStartToken int32
|
||||
imageEndToken int32
|
||||
audioTokenID int32
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*OmniModel)(nil)
|
||||
|
||||
func NewOmni(c fs.Config) (model.Model, error) {
|
||||
textModel, err := newTextModel(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imageTokenID := int32(c.Uint("vision.image_token_id", 18))
|
||||
imageStartToken := int32(c.Uint("vision.image_start_token_id", 19))
|
||||
imageEndToken := int32(c.Uint("vision.image_end_token_id", 20))
|
||||
audioTokenID := int32(c.Uint("audio.sound_token_id", 27))
|
||||
|
||||
return &OmniModel{
|
||||
Model: textModel,
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: newMultiModalProjector(c),
|
||||
AudioProjector: newAudioProjector(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageStartToken: imageStartToken,
|
||||
imageEndToken: imageEndToken,
|
||||
audioTokenID: audioTokenID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *OmniModel) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if isAudioData(multimodalData) {
|
||||
return m.encodeAudioMultimodal(ctx, multimodalData)
|
||||
}
|
||||
|
||||
if m.VisionModel == nil || m.MultiModalProjector == nil || len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tiles, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mm := make([]input.Multimodal, 0, len(tiles))
|
||||
for _, tile := range tiles {
|
||||
patches := visionPatchGrid{
|
||||
Width: tile.size.X / m.ImageProcessor.patchSize,
|
||||
Height: tile.size.Y / m.ImageProcessor.patchSize,
|
||||
}
|
||||
if patches.Width == 0 || patches.Height == 0 {
|
||||
return nil, errors.New("nemotron_h_omni: invalid resized image dimensions")
|
||||
}
|
||||
|
||||
patchInput := packVisionPatchesCHW(tile.data, tile.size.X, tile.size.Y, m.ImageProcessor.numChannels, m.ImageProcessor.patchSize)
|
||||
visionOutputs := m.VisionModel.ForwardPacked(ctx, patchInput, patches)
|
||||
projected := m.MultiModalProjector.Forward(ctx, visionOutputs, patches)
|
||||
mm = append(mm, input.Multimodal{Tensor: projected})
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
type audioTag struct{}
|
||||
|
||||
func (m *OmniModel) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
|
||||
if m.AudioModel == nil || m.AudioProjector == nil || len(m.AudioModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
samples, err := decodeWAV(data, m.AudioModel.sampleRate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
melData, frames, validFrames, err := computeParakeetMelSpectrogram(samples, m.AudioModel.FeatureExtractor, m.AudioModel.AudioOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
melTensor := ctx.Input().FromFloats(melData, m.AudioModel.melBins, frames)
|
||||
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, validFrames, m.AudioProjector)
|
||||
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
|
||||
}
|
||||
|
||||
func (m *OmniModel) PostLoad() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *OmniModel) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
imageToken := m.imageTokenID
|
||||
if imageToken == 0 {
|
||||
imageToken = 18
|
||||
}
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
totalTokens := 0
|
||||
for _, mm := range inp.Multimodal {
|
||||
if mm.Tensor == nil {
|
||||
continue
|
||||
}
|
||||
totalTokens += mm.Tensor.Dim(1)
|
||||
}
|
||||
if totalTokens <= 0 {
|
||||
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
|
||||
}
|
||||
|
||||
if _, ok := inp.Multimodal[0].Data.(audioTag); ok {
|
||||
audioToken := m.audioTokenID
|
||||
if audioToken == 0 {
|
||||
audioToken = 27
|
||||
}
|
||||
|
||||
for i, mm := range inp.Multimodal {
|
||||
tokenCount := 0
|
||||
if mm.Tensor != nil {
|
||||
tokenCount = mm.Tensor.Dim(1)
|
||||
}
|
||||
if tokenCount <= 0 {
|
||||
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
|
||||
}
|
||||
|
||||
first := &input.Input{Token: audioToken, SameBatch: tokenCount - 1}
|
||||
if i == 0 {
|
||||
first.MultimodalHash = inp.MultimodalHash
|
||||
}
|
||||
first.Multimodal = []input.Multimodal{mm}
|
||||
result = append(result, first)
|
||||
if tokenCount > 1 {
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: audioToken}}, tokenCount-1)...)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if m.imageStartToken > 0 {
|
||||
result = append(result, &input.Input{
|
||||
Token: m.imageStartToken,
|
||||
SameBatch: totalTokens + btoi(m.imageEndToken > 0),
|
||||
})
|
||||
}
|
||||
|
||||
for _, mm := range inp.Multimodal {
|
||||
tokenCount := 0
|
||||
if mm.Tensor != nil {
|
||||
tokenCount = mm.Tensor.Dim(1)
|
||||
}
|
||||
if tokenCount <= 0 {
|
||||
return nil, errors.New("nemotron_h_omni: multimodal input has no tokens")
|
||||
}
|
||||
|
||||
result = append(result, &input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: []input.Multimodal{mm},
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
})
|
||||
if tokenCount > 1 {
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, tokenCount-1)...)
|
||||
}
|
||||
}
|
||||
|
||||
if m.imageEndToken > 0 {
|
||||
result = append(result, &input.Input{Token: m.imageEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func btoi(v bool) int {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *OmniModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
hiddenStates = hiddenStates.Duplicate(ctx)
|
||||
}
|
||||
|
||||
for _, mm := range batch.Multimodal {
|
||||
offset := mm.Index
|
||||
for _, multimodal := range mm.Multimodal {
|
||||
if multimodal.Tensor == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tensor := multimodal.Tensor
|
||||
ctx.Forward(tensor.Copy(ctx, hiddenStates.View(ctx, offset*hiddenStates.Stride(1), tensor.Dim(0)*tensor.Dim(1))))
|
||||
offset += tensor.Dim(1)
|
||||
}
|
||||
}
|
||||
|
||||
return m.forwardLogits(ctx, batch, hiddenStates)
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("nemotron_h_omni", NewOmni)
|
||||
}
|
||||
606
model/models/nemotronh/model_omni_test.go
Normal file
606
model/models/nemotronh/model_omni_test.go
Normal file
@@ -0,0 +1,606 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
backendggml "github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type fakeTensor struct {
|
||||
*backendggml.Tensor
|
||||
dims []int
|
||||
}
|
||||
|
||||
func (t *fakeTensor) Dim(i int) int {
|
||||
return t.dims[i]
|
||||
}
|
||||
|
||||
func setupTestContext(t *testing.T) ml.Context {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := fsggml.WriteGGUF(f, fsggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := b.NewContext().Input()
|
||||
t.Cleanup(func() {
|
||||
ctx.Close()
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestPostTokenizeImageSpans(t *testing.T) {
|
||||
m := &OmniModel{
|
||||
imageTokenID: 18,
|
||||
imageStartToken: 19,
|
||||
imageEndToken: 20,
|
||||
}
|
||||
|
||||
makeChunk := func() input.Multimodal {
|
||||
return input.Multimodal{Tensor: &fakeTensor{dims: []int{2688, 256, 1, 1}}}
|
||||
}
|
||||
|
||||
in := []*input.Input{
|
||||
{Token: 7},
|
||||
{
|
||||
Multimodal: []input.Multimodal{makeChunk(), makeChunk()},
|
||||
MultimodalHash: 99,
|
||||
},
|
||||
{Token: 8},
|
||||
}
|
||||
|
||||
out, err := m.PostTokenize(in)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out) != 516 {
|
||||
t.Fatalf("len(out) = %d, want 516", len(out))
|
||||
}
|
||||
|
||||
if out[0].Token != 7 {
|
||||
t.Fatalf("out[0].Token = %d, want 7", out[0].Token)
|
||||
}
|
||||
if out[1].Token != 19 {
|
||||
t.Fatalf("out[1].Token = %d, want 19", out[1].Token)
|
||||
}
|
||||
if out[1].SameBatch != 513 {
|
||||
t.Fatalf("out[1].SameBatch = %d, want 513", out[1].SameBatch)
|
||||
}
|
||||
if out[2].Token != 18 || len(out[2].Multimodal) != 1 || out[2].MultimodalHash != 99 || out[2].SameBatch != 0 {
|
||||
t.Fatalf("unexpected first image token: %+v", *out[2])
|
||||
}
|
||||
if out[258].Token != 18 || len(out[258].Multimodal) != 1 || out[258].MultimodalHash != 99 || out[258].SameBatch != 0 {
|
||||
t.Fatalf("unexpected second image token: %+v", *out[258])
|
||||
}
|
||||
if out[514].Token != 20 {
|
||||
t.Fatalf("out[514].Token = %d, want 20", out[514].Token)
|
||||
}
|
||||
if out[515].Token != 8 {
|
||||
t.Fatalf("out[515].Token = %d, want 8", out[515].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectorPixelShuffleMatchesReferenceV2Order(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
hidden := 2
|
||||
width := 4
|
||||
height := 2
|
||||
values := make([]float32, 0, hidden*width*height)
|
||||
for y := range height {
|
||||
for x := range width {
|
||||
for c := range hidden {
|
||||
values = append(values, float32(100*y+10*x+c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := pixelShuffleVisionOutputs(ctx, ctx.FromFloats(values, hidden, width*height), visionPatchGrid{
|
||||
Width: width,
|
||||
Height: height,
|
||||
}, 2)
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
0, 1, 10, 11, 100, 101, 110, 111,
|
||||
20, 21, 30, 31, 120, 121, 130, 131,
|
||||
}
|
||||
if got.Shape()[0] != 8 || got.Shape()[1] != 2 {
|
||||
t.Fatalf("shape = %v, want [8 2 1]", got.Shape())
|
||||
}
|
||||
gotValues := got.BackendGet()
|
||||
if len(gotValues) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(gotValues), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if gotValues[i] != want[i] {
|
||||
t.Fatalf("got[%d] = %v, want %v", i, gotValues[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTokenizeAudioSpans(t *testing.T) {
|
||||
m := &OmniModel{
|
||||
audioTokenID: 27,
|
||||
}
|
||||
|
||||
in := []*input.Input{
|
||||
{Token: 7},
|
||||
{
|
||||
Multimodal: []input.Multimodal{{
|
||||
Tensor: &fakeTensor{dims: []int{2688, 13, 1, 1}},
|
||||
Data: audioTag{},
|
||||
}},
|
||||
MultimodalHash: 99,
|
||||
},
|
||||
{Token: 8},
|
||||
}
|
||||
|
||||
out, err := m.PostTokenize(in)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out) != 15 {
|
||||
t.Fatalf("len(out) = %d, want 15", len(out))
|
||||
}
|
||||
if out[0].Token != 7 || out[14].Token != 8 {
|
||||
t.Fatalf("unexpected surrounding tokens: first=%d last=%d", out[0].Token, out[14].Token)
|
||||
}
|
||||
for i := 1; i <= 13; i++ {
|
||||
if out[i].Token != 27 {
|
||||
t.Fatalf("out[%d].Token = %d, want 27", i, out[i].Token)
|
||||
}
|
||||
}
|
||||
if len(out[1].Multimodal) != 1 || out[1].MultimodalHash != 99 {
|
||||
t.Fatalf("first audio token did not carry multimodal payload: %+v", *out[1])
|
||||
}
|
||||
if out[1].SameBatch != 12 {
|
||||
t.Fatalf("first audio token SameBatch = %d, want 12", out[1].SameBatch)
|
||||
}
|
||||
if len(out[2].Multimodal) != 0 {
|
||||
t.Fatalf("only the first audio token should carry multimodal payload: %+v", *out[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParakeetAudioPreprocessShapes(t *testing.T) {
|
||||
data := sineWAV(t, 16000, 440, 1.0)
|
||||
samples, err := decodeWAV(data, 16000)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := len(samples), 16000; got != want {
|
||||
t.Fatalf("sample count = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
mel, frames, validFrames, err := computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if frames != 101 {
|
||||
t.Fatalf("frames = %d, want 101", frames)
|
||||
}
|
||||
if validFrames != 100 {
|
||||
t.Fatalf("validFrames = %d, want 100", validFrames)
|
||||
}
|
||||
if len(mel) != 101*128 {
|
||||
t.Fatalf("len(mel) = %d, want %d", len(mel), 101*128)
|
||||
}
|
||||
lastFrame := mel[100*128 : 101*128]
|
||||
if !slices.Equal(lastFrame, make([]float32, 128)) {
|
||||
t.Fatal("expected masked final frame to be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParakeetAudioPreprocessMatchesIntegrationWAVReference(t *testing.T) {
|
||||
data := integrationAudioWAV(t)
|
||||
samples, err := decodeWAV(data, 16000)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := len(samples), 42083; got != want {
|
||||
t.Fatalf("sample count = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
mel, frames, validFrames, err := computeParakeetMelSpectrogram(samples, nil, defaultAudioOptions())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if frames != 264 {
|
||||
t.Fatalf("frames = %d, want 264", frames)
|
||||
}
|
||||
if validFrames != 263 {
|
||||
t.Fatalf("validFrames = %d, want 263", validFrames)
|
||||
}
|
||||
if len(mel) != 264*128 {
|
||||
t.Fatalf("len(mel) = %d, want %d", len(mel), 264*128)
|
||||
}
|
||||
lastFrame := mel[263*128 : 264*128]
|
||||
if !slices.Equal(lastFrame, make([]float32, 128)) {
|
||||
t.Fatal("expected masked final frame to be zero")
|
||||
}
|
||||
|
||||
// Reference values come from the ParakeetExtractor path used by vLLM:
|
||||
// pre-emphasis, torch.stft(center=True, pad_mode="constant"), Slaney mel
|
||||
// filters, log guard 2^-24, and per-mel normalization over valid frames.
|
||||
checks := map[[2]int]float32{
|
||||
{0, 0}: -1.0855197,
|
||||
{0, 50}: -0.93212974,
|
||||
{1, 10}: -0.9735168,
|
||||
{2, 100}: -0.6533053,
|
||||
{50, 0}: 2.2483668,
|
||||
{50, 127}: -0.3828735,
|
||||
{100, 50}: 2.9742377,
|
||||
{262, 0}: -0.9521758,
|
||||
{262, 127}: -0.4602786,
|
||||
{263, 50}: 0,
|
||||
}
|
||||
for pos, want := range checks {
|
||||
got := mel[pos[0]*128+pos[1]]
|
||||
if math.Abs(float64(got-want)) > 1e-4 {
|
||||
t.Errorf("mel[%d,%d] = %v, want %v", pos[0], pos[1], got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func integrationAudioWAV(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join("..", "..", "..", "integration", "audio_test_data_test.go")
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const marker = "const audioEncodingPrompt = `"
|
||||
s := string(b)
|
||||
start := strings.Index(s, marker)
|
||||
if start < 0 {
|
||||
t.Fatal("audioEncodingPrompt marker not found")
|
||||
}
|
||||
start += len(marker)
|
||||
end := strings.Index(s[start:], "`")
|
||||
if end < 0 {
|
||||
t.Fatal("audioEncodingPrompt terminator not found")
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(strings.TrimSpace(s[start : start+end]))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestRelativeShiftParakeetMatchesReference(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
seqLen := 3
|
||||
positionLen := 2*seqLen - 1
|
||||
values := make([]float32, seqLen*positionLen)
|
||||
for q := range seqLen {
|
||||
for p := range positionLen {
|
||||
values[q*positionLen+p] = float32(q*10 + p)
|
||||
}
|
||||
}
|
||||
|
||||
x := ctx.FromFloats(values, positionLen, seqLen, 1)
|
||||
got := relativeShiftParakeet(ctx, x, seqLen, 1)
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
2, 3, 4,
|
||||
11, 12, 13,
|
||||
20, 21, 22,
|
||||
}
|
||||
if !slices.Equal(got.BackendGet(), want) {
|
||||
t.Fatalf("relative shift mismatch:\n got %v\nwant %v", got.BackendGet(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAudioDepthwiseConv2DMatchesReference(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
freq, frames, channels := 4, 5, 2
|
||||
xValues := make([]float32, freq*frames*channels)
|
||||
for i := range xValues {
|
||||
xValues[i] = float32(i)/10 - 1
|
||||
}
|
||||
kernelValues := make([]float32, 3*3*channels)
|
||||
for i := range kernelValues {
|
||||
kernelValues[i] = float32(i)/7 - 1
|
||||
}
|
||||
|
||||
x := ctx.FromFloats(xValues, freq, frames, channels, 1)
|
||||
kernel := ctx.FromFloats(kernelValues, 3, 3, 1, channels)
|
||||
bias := ctx.FromFloats([]float32{0.25, -0.5}, channels)
|
||||
got := audioDepthwiseConv2D(ctx, x, kernel, 2, 2, 1, 1, 1, 1).Add(ctx, bias.Reshape(ctx, 1, 1, -1))
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
0.86428565, 1.3357141,
|
||||
1.2785715, 1.3642857,
|
||||
-0.5928571, -1.7499999,
|
||||
5.4000001, 8.8142853,
|
||||
10.514286, 16.042856,
|
||||
6.6857138, 9.8428574,
|
||||
}
|
||||
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
|
||||
}
|
||||
|
||||
func TestFlattenAudioSubsamplingOutputMatchesReference(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
const (
|
||||
freq = 2
|
||||
frames = 3
|
||||
channels = 2
|
||||
)
|
||||
values := make([]float32, freq*frames*channels)
|
||||
for c := range channels {
|
||||
for t := range frames {
|
||||
for f := range freq {
|
||||
values[f+freq*(t+frames*c)] = float32(100*c + 10*t + f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := flattenAudioSubsamplingOutput(ctx, ctx.FromFloats(values, freq, frames, channels, 1))
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
0, 1, 100, 101,
|
||||
10, 11, 110, 111,
|
||||
20, 21, 120, 121,
|
||||
}
|
||||
assertCloseSlice(t, got.BackendGet(), want, 0)
|
||||
}
|
||||
|
||||
func TestAudioDepthwiseConv1DMatchesReference(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
xValues := make([]float32, 2*5)
|
||||
for i := range xValues {
|
||||
xValues[i] = float32(i)/5 - 0.7
|
||||
}
|
||||
kernelValues := make([]float32, 3*2)
|
||||
for i := range kernelValues {
|
||||
kernelValues[i] = float32(i)/3 - 0.5
|
||||
}
|
||||
|
||||
x := ctx.FromFloats(xValues, 2, 5)
|
||||
kernel := ctx.FromFloats(kernelValues, 3, 2)
|
||||
got := audioDepthwiseConv1DSame(ctx, x, kernel, 1)
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
0.066666655, -0.5333333,
|
||||
0.41666666, 0.016666688,
|
||||
0.21666668, 1.0166667,
|
||||
0.01666667, 2.0166664,
|
||||
-0.40000004, 1.2666667,
|
||||
}
|
||||
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
|
||||
}
|
||||
|
||||
func TestAudioSelfAttentionMatchesReference(t *testing.T) {
|
||||
ctx := setupTestContext(t)
|
||||
|
||||
const (
|
||||
hiddenSize = 4
|
||||
numHeads = 2
|
||||
headDim = 2
|
||||
seqLen = 3
|
||||
)
|
||||
xValues := make([]float32, hiddenSize*seqLen)
|
||||
for i := range xValues {
|
||||
xValues[i] = float32(i)/10 - 0.5
|
||||
}
|
||||
|
||||
identity := make([]float32, hiddenSize*hiddenSize)
|
||||
for i := range hiddenSize {
|
||||
identity[i*hiddenSize+i] = 1
|
||||
}
|
||||
linear := func() *nn.Linear {
|
||||
return &nn.Linear{Weight: ctx.FromFloats(identity, hiddenSize, hiddenSize)}
|
||||
}
|
||||
|
||||
attn := &AudioSelfAttention{
|
||||
Query: linear(),
|
||||
Key: linear(),
|
||||
Value: linear(),
|
||||
Output: linear(),
|
||||
RelativeKey: linear(),
|
||||
BiasU: ctx.FromFloats([]float32{0.1, -0.2, 0.3, -0.4}, headDim, numHeads),
|
||||
BiasV: ctx.FromFloats([]float32{-0.05, 0.07, 0.11, -0.13}, headDim, numHeads),
|
||||
}
|
||||
got := attn.Forward(ctx, ctx.FromFloats(xValues, hiddenSize, seqLen), seqLen, &AudioOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
})
|
||||
ctx.Forward(got).Compute(got)
|
||||
|
||||
want := []float32{
|
||||
-0.08471569, 0.015284289, 0.05532019, 0.1553202,
|
||||
-0.09135241, 0.008647568, 0.11468154, 0.21468155,
|
||||
-0.019152153, 0.08084783, 0.1733382, 0.2733382,
|
||||
}
|
||||
assertCloseSlice(t, got.BackendGet(), want, 1e-5)
|
||||
}
|
||||
|
||||
func assertCloseSlice(t *testing.T, got, want []float32, tolerance float64) {
|
||||
t.Helper()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if math.Abs(float64(got[i]-want[i])) > tolerance {
|
||||
t.Fatalf("got[%d] = %v, want %v\nall got: %v", i, got[i], want[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackPatchesCHW(t *testing.T) {
|
||||
values := []float32{
|
||||
0, 1, 2, 3,
|
||||
4, 5, 6, 7,
|
||||
8, 9, 10, 11,
|
||||
12, 13, 14, 15,
|
||||
100, 101, 102, 103,
|
||||
104, 105, 106, 107,
|
||||
108, 109, 110, 111,
|
||||
112, 113, 114, 115,
|
||||
}
|
||||
|
||||
got := packVisionPatchesCHW(values, 4, 4, 2, 2)
|
||||
want := []float32{
|
||||
0, 1, 4, 5, 100, 101, 104, 105,
|
||||
2, 3, 6, 7, 102, 103, 106, 107,
|
||||
8, 9, 12, 13, 108, 109, 112, 113,
|
||||
10, 11, 14, 15, 110, 111, 114, 115,
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("got[%d] = %v, want %v", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResizePositionEmbeddingMatchesReferenceInterpolation(t *testing.T) {
|
||||
values := []float32{
|
||||
0, 10,
|
||||
20, 30,
|
||||
}
|
||||
got := resizePositionEmbedding(values, 1, 2, 2, 3, 3)
|
||||
want := []float32{
|
||||
0, 5, 10,
|
||||
10, 15, 20,
|
||||
20, 25, 30,
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("got[%d] = %v, want %v", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicImageProcessorMatchesReferencePatchBudget(t *testing.T) {
|
||||
p := ImageProcessor{
|
||||
imageSize: 512,
|
||||
patchSize: 16,
|
||||
numChannels: 3,
|
||||
minNumPatches: 1024,
|
||||
maxNumPatches: 13312,
|
||||
projectorScale: 2,
|
||||
imageMean: [3]float32{0.48145466, 0.4578275, 0.40821073},
|
||||
imageStd: [3]float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, 400, 250))
|
||||
bounds := img.Bounds()
|
||||
width, height := bounds.Dx(), bounds.Dy()
|
||||
for y := range height {
|
||||
for x := range width {
|
||||
img.SetRGBA(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255})
|
||||
}
|
||||
}
|
||||
|
||||
tiles, err := p.ProcessImage(img)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessImage() error = %v", err)
|
||||
}
|
||||
if got, want := len(tiles), 1; got != want {
|
||||
t.Fatalf("len(tiles) = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := tiles[0].size, (image.Point{X: 672, Y: 416}); got != want {
|
||||
t.Fatalf("tile size = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := len(tiles[0].data), 3*672*416; got != want {
|
||||
t.Fatalf("tile data len = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sineWAV(t *testing.T, sampleRate int, frequency float64, seconds float64) []byte {
|
||||
t.Helper()
|
||||
|
||||
samples := int(float64(sampleRate) * seconds)
|
||||
var pcm bytes.Buffer
|
||||
for i := range samples {
|
||||
v := int16(math.Sin(2*math.Pi*frequency*float64(i)/float64(sampleRate)) * 32767)
|
||||
if err := binary.Write(&pcm, binary.LittleEndian, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
out.WriteString("RIFF")
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint32(36+pcm.Len())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out.WriteString("WAVE")
|
||||
out.WriteString("fmt ")
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint32(16)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint16(1)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint16(1)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint32(sampleRate)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint32(sampleRate*2)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint16(2)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint16(16)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out.WriteString("data")
|
||||
if err := binary.Write(&out, binary.LittleEndian, uint32(pcm.Len())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out.Write(pcm.Bytes())
|
||||
return out.Bytes()
|
||||
}
|
||||
348
model/models/nemotronh/model_vision.go
Normal file
348
model/models/nemotronh/model_vision.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
const nemotronVisionBatchSize = 1
|
||||
|
||||
type visionPatchGrid struct {
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
|
||||
type VisionPatchEmbedding struct {
|
||||
*nn.Linear
|
||||
}
|
||||
|
||||
func packVisionPatchesCHW(values []float32, width, height, channels, patchSize int) []float32 {
|
||||
patchesX, patchesY := width/patchSize, height/patchSize
|
||||
patchDim := channels * patchSize * patchSize
|
||||
plane := width * height
|
||||
|
||||
patches := make([]float32, patchDim*patchesX*patchesY)
|
||||
offset := 0
|
||||
for py := range patchesY {
|
||||
for px := range patchesX {
|
||||
for c := range channels {
|
||||
channelBase := c * plane
|
||||
for yy := range patchSize {
|
||||
rowBase := (py*patchSize + yy) * width
|
||||
for xx := range patchSize {
|
||||
patches[offset] = values[channelBase+rowBase+px*patchSize+xx]
|
||||
offset++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return patches
|
||||
}
|
||||
|
||||
func (p *VisionPatchEmbedding) ForwardPacked(ctx ml.Context, patches []float32, patchDim, numPatches int) ml.Tensor {
|
||||
hiddenState := ctx.Input().FromFloats(patches, patchDim, numPatches)
|
||||
hiddenState = hiddenState.Duplicate(ctx)
|
||||
return p.Linear.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
func (p *VisionPatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
|
||||
// Match the RADIO patch generator's exact flattening order: patches are laid
|
||||
// out token-major with each token packed as channel, then patch-row, then
|
||||
// patch-col. This is more explicit than the prior IM2Col path and likely
|
||||
// slower, but it avoids backend-specific packing differences that caused the
|
||||
// converted patch embedder to diverge badly from the reference model.
|
||||
width, height, channels := pixelValues.Dim(0), pixelValues.Dim(1), pixelValues.Dim(2)
|
||||
patchesX, patchesY := width/patchSize, height/patchSize
|
||||
patchDim := channels * patchSize * patchSize
|
||||
|
||||
values := pixelValues.BackendGet()
|
||||
return p.ForwardPacked(ctx, packVisionPatchesCHW(values, width, height, channels, patchSize), patchDim, patchesX*patchesY)
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), nemotronVisionBatchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), nemotronVisionBatchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), nemotronVisionBatchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), nemotronVisionBatchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
LayerNorm2 *nn.LayerNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
|
||||
residual = hiddenState
|
||||
hiddenState = l.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
imageSize int
|
||||
patchSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *VisionPatchEmbedding `gguf:"patch_embd"`
|
||||
PositionEmbedding ml.Tensor `gguf:"position_embd"`
|
||||
ClassEmbedding ml.Tensor `gguf:"cls_embd"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
*VisionOptions
|
||||
|
||||
resizedPositionEmbeddingsMu sync.Mutex
|
||||
resizedPositionEmbeddings map[visionPatchGrid][]float32
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
|
||||
numPatches := patches.Width * patches.Height
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
|
||||
return m.forwardPatchEmbeddings(ctx, hiddenState, patches, numPatches)
|
||||
}
|
||||
|
||||
func (m *VisionModel) ForwardPacked(ctx ml.Context, patchValues []float32, patches visionPatchGrid) ml.Tensor {
|
||||
numPatches := patches.Width * patches.Height
|
||||
patchDim := 0
|
||||
if numPatches > 0 {
|
||||
patchDim = len(patchValues) / numPatches
|
||||
}
|
||||
hiddenState := m.PatchEmbedding.ForwardPacked(ctx, patchValues, patchDim, numPatches)
|
||||
return m.forwardPatchEmbeddings(ctx, hiddenState, patches, numPatches)
|
||||
}
|
||||
|
||||
func (m *VisionModel) forwardPatchEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor {
|
||||
if m.PositionEmbedding != nil {
|
||||
positionEmbeddings := m.positionEmbeddings(ctx, hiddenState, patches, numPatches)
|
||||
hiddenState = hiddenState.Add(ctx, positionEmbeddings)
|
||||
}
|
||||
|
||||
if m.ClassEmbedding != nil {
|
||||
numPrefixTokens := m.ClassEmbedding.Dim(1)
|
||||
classEmbeddings := m.ClassEmbedding.Cast(ctx, hiddenState.DType())
|
||||
classEmbeddings = classEmbeddings.Reshape(ctx, classEmbeddings.Dim(0), numPrefixTokens, 1)
|
||||
hiddenState = classEmbeddings.Concat(ctx, hiddenState, 1)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionOptions)
|
||||
}
|
||||
|
||||
if m.ClassEmbedding != nil {
|
||||
hiddenState = hiddenState.Slice(ctx, 1, m.ClassEmbedding.Dim(1), hiddenState.Dim(1), 1)
|
||||
}
|
||||
|
||||
return hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1))
|
||||
}
|
||||
|
||||
func (m *VisionModel) positionEmbeddings(ctx ml.Context, hiddenState ml.Tensor, patches visionPatchGrid, numPatches int) ml.Tensor {
|
||||
posTokens := m.PositionEmbedding.Dim(1)
|
||||
source := int(math.Sqrt(float64(posTokens)))
|
||||
|
||||
positionEmbeddings := m.PositionEmbedding.Cast(ctx, hiddenState.DType())
|
||||
if !(source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height)) {
|
||||
if positionEmbeddings.Dim(1) > numPatches {
|
||||
positionEmbeddings = positionEmbeddings.Slice(ctx, 1, 0, numPatches, 1)
|
||||
}
|
||||
return positionEmbeddings
|
||||
}
|
||||
|
||||
if cached, ok := m.cachePositionEmbeddings(ctx, hiddenState.Dim(0), patches); ok {
|
||||
return ctx.Input().FromFloats(cached, hiddenState.Dim(0), numPatches)
|
||||
}
|
||||
|
||||
// Runner fit/reserve builds worst-case multimodal graphs before weights are
|
||||
// loaded, so the align-corners CPU cache path cannot materialize source
|
||||
// values there. Fall back to a graph-only bilinear resize for reservation;
|
||||
// the loaded inference path above still uses the cached align-corners data.
|
||||
positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
|
||||
positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
|
||||
patches.Width,
|
||||
patches.Height,
|
||||
hiddenState.Dim(0),
|
||||
1,
|
||||
}, ml.SamplingModeBilinear)
|
||||
positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
|
||||
return positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
|
||||
}
|
||||
|
||||
func (m *VisionModel) cachePositionEmbeddings(ctx ml.Context, hidden int, patches visionPatchGrid) ([]float32, bool) {
|
||||
m.resizedPositionEmbeddingsMu.Lock()
|
||||
cached := m.resizedPositionEmbeddings[patches]
|
||||
m.resizedPositionEmbeddingsMu.Unlock()
|
||||
if cached != nil {
|
||||
return cached, true
|
||||
}
|
||||
|
||||
if len(m.PositionEmbedding.Bytes()) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
posTokens := m.PositionEmbedding.Dim(1)
|
||||
source := int(math.Sqrt(float64(posTokens)))
|
||||
positionEmbeddingsF32 := m.PositionEmbedding.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(positionEmbeddingsF32).Compute(positionEmbeddingsF32)
|
||||
|
||||
// RADIO eval-time CPE uses bilinear interpolation with align_corners=false.
|
||||
// Cache a CPU-resized token-major embedding here for correctness first. This
|
||||
// is likely slower than a native graph path and should be revisited if this
|
||||
// precision vs speed tradeoff is not worthwhile.
|
||||
cached = resizePositionEmbedding(positionEmbeddingsF32.Floats(), hidden, source, source, patches.Width, patches.Height)
|
||||
|
||||
m.resizedPositionEmbeddingsMu.Lock()
|
||||
if m.resizedPositionEmbeddings == nil {
|
||||
m.resizedPositionEmbeddings = make(map[visionPatchGrid][]float32)
|
||||
}
|
||||
if existing := m.resizedPositionEmbeddings[patches]; existing != nil {
|
||||
cached = existing
|
||||
} else {
|
||||
m.resizedPositionEmbeddings[patches] = cached
|
||||
}
|
||||
m.resizedPositionEmbeddingsMu.Unlock()
|
||||
|
||||
return cached, true
|
||||
}
|
||||
|
||||
func resizePositionEmbedding(values []float32, hidden, sourceWidth, sourceHeight, targetWidth, targetHeight int) []float32 {
|
||||
out := make([]float32, hidden*targetWidth*targetHeight)
|
||||
|
||||
scaleX := float64(sourceWidth) / float64(targetWidth)
|
||||
scaleY := float64(sourceHeight) / float64(targetHeight)
|
||||
|
||||
for oy := range targetHeight {
|
||||
srcY := scaleY*(float64(oy)+0.5) - 0.5
|
||||
y0 := int(math.Floor(srcY))
|
||||
y1 := min(y0+1, sourceHeight-1)
|
||||
wy := float32(srcY - float64(y0))
|
||||
y0 = max(y0, 0)
|
||||
|
||||
for ox := range targetWidth {
|
||||
srcX := scaleX*(float64(ox)+0.5) - 0.5
|
||||
x0 := int(math.Floor(srcX))
|
||||
x1 := min(x0+1, sourceWidth-1)
|
||||
wx := float32(srcX - float64(x0))
|
||||
x0 = max(x0, 0)
|
||||
|
||||
t00 := (y0*sourceWidth + x0) * hidden
|
||||
t01 := (y0*sourceWidth + x1) * hidden
|
||||
t10 := (y1*sourceWidth + x0) * hidden
|
||||
t11 := (y1*sourceWidth + x1) * hidden
|
||||
dst := (oy*targetWidth + ox) * hidden
|
||||
for h := range hidden {
|
||||
v00 := values[t00+h]
|
||||
v01 := values[t01+h]
|
||||
v10 := values[t10+h]
|
||||
v11 := values[t11+h]
|
||||
top := v00 + (v01-v00)*wx
|
||||
bot := v10 + (v11-v10)*wx
|
||||
out[dst+h] = top + (bot-top)*wy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
||||
VisionOptions: &VisionOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length", 1280)),
|
||||
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||
imageSize: int(c.Uint("vision.image_size", 512)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type MultiModalProjector struct {
|
||||
Norm *nn.RMSNorm `gguf:"norm"`
|
||||
Linear1 *nn.Linear `gguf:"1"`
|
||||
Linear2 *nn.Linear `gguf:"2"`
|
||||
|
||||
scaleFactor int
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid) ml.Tensor {
|
||||
scaleFactor := max(p.scaleFactor, 1)
|
||||
|
||||
// The reference projector first pixel-shuffles the vision grid with
|
||||
// downsample_ratio=0.5 before applying the RMSNorm/MLP. Preserve that exact
|
||||
// v2 packing order here rather than flattening 2x2 neighborhoods via IM2Col.
|
||||
merged := pixelShuffleVisionOutputs(ctx, visionOutputs, patches, scaleFactor)
|
||||
|
||||
merged = p.Norm.Forward(ctx, merged, 1e-5)
|
||||
merged = p.Linear1.Forward(ctx, merged)
|
||||
merged = merged.RELU(ctx)
|
||||
merged = merged.Mul(ctx, merged)
|
||||
return p.Linear2.Forward(ctx, merged)
|
||||
}
|
||||
|
||||
func pixelShuffleVisionOutputs(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, scaleFactor int) ml.Tensor {
|
||||
hiddenSize := visionOutputs.Dim(0)
|
||||
scaleFactor = max(scaleFactor, 1)
|
||||
|
||||
merged := visionOutputs.Reshape(ctx, hiddenSize, patches.Width, patches.Height, 1)
|
||||
|
||||
width := patches.Width / scaleFactor
|
||||
height := patches.Height / scaleFactor
|
||||
channels := hiddenSize * scaleFactor
|
||||
|
||||
merged = merged.Reshape(ctx, channels, width, patches.Height, 1)
|
||||
merged = merged.Reshape(ctx, channels, width, scaleFactor, height)
|
||||
merged = merged.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
return merged.Reshape(ctx, channels*scaleFactor, width*height, 1)
|
||||
}
|
||||
|
||||
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
|
||||
return &MultiModalProjector{
|
||||
scaleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
|
||||
}
|
||||
}
|
||||
328
model/models/nemotronh/process_audio.go
Normal file
328
model/models/nemotronh/process_audio.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/cmplx"
|
||||
)
|
||||
|
||||
const (
|
||||
parakeetHopLength = 160
|
||||
parakeetNFFT = 512
|
||||
parakeetWinLength = 400
|
||||
parakeetPreemphasis = 0.97
|
||||
parakeetLogZeroGuardValue = 1.0 / (1 << 24)
|
||||
parakeetNormalizeEps = 1e-5
|
||||
)
|
||||
|
||||
func isAudioData(data []byte) bool {
|
||||
return len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WAVE"
|
||||
}
|
||||
|
||||
func decodeWAV(data []byte, targetSampleRate int) ([]float32, error) {
|
||||
if len(data) < 12 {
|
||||
return nil, fmt.Errorf("WAV file too short")
|
||||
}
|
||||
if !isAudioData(data) {
|
||||
return nil, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
|
||||
var audioFormat uint16
|
||||
var numChannels, sampleRate, bitsPerSample int
|
||||
var audioData []byte
|
||||
foundFmt := false
|
||||
|
||||
offset := 12
|
||||
for offset+8 <= len(data) {
|
||||
chunkID := string(data[offset : offset+4])
|
||||
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
||||
chunkEnd := min(offset+8+chunkSize, len(data))
|
||||
chunkData := data[offset+8 : chunkEnd]
|
||||
|
||||
switch chunkID {
|
||||
case "fmt ":
|
||||
if len(chunkData) < 16 {
|
||||
return nil, fmt.Errorf("fmt chunk too short")
|
||||
}
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
|
||||
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
|
||||
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
|
||||
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
|
||||
if audioFormat == 0xfffe && len(chunkData) >= 26 {
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
|
||||
}
|
||||
foundFmt = true
|
||||
case "data":
|
||||
audioData = chunkData
|
||||
}
|
||||
|
||||
offset += 8 + chunkSize
|
||||
if chunkSize%2 != 0 {
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt {
|
||||
return nil, fmt.Errorf("no fmt chunk found in WAV file")
|
||||
}
|
||||
if audioFormat != 1 && audioFormat != 3 {
|
||||
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
|
||||
}
|
||||
if audioData == nil {
|
||||
return nil, fmt.Errorf("no data chunk found in WAV file")
|
||||
}
|
||||
if numChannels <= 0 {
|
||||
return nil, fmt.Errorf("invalid WAV channel count: %d", numChannels)
|
||||
}
|
||||
|
||||
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
|
||||
if sampleRate != targetSampleRate {
|
||||
samples = resampleLinear(samples, sampleRate, targetSampleRate)
|
||||
}
|
||||
return samples, nil
|
||||
}
|
||||
|
||||
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
|
||||
bytesPerSample := bits / 8
|
||||
if bytesPerSample <= 0 || channels <= 0 {
|
||||
return nil
|
||||
}
|
||||
totalSamples := len(data) / (bytesPerSample * channels)
|
||||
mono := make([]float32, totalSamples)
|
||||
|
||||
for i := range totalSamples {
|
||||
var sum float64
|
||||
for ch := range channels {
|
||||
off := (i*channels + ch) * bytesPerSample
|
||||
if off+bytesPerSample > len(data) {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case format == 1 && bits == 16:
|
||||
v := int16(binary.LittleEndian.Uint16(data[off : off+2]))
|
||||
sum += float64(v) / 32768.0
|
||||
case format == 1 && bits == 32:
|
||||
v := int32(binary.LittleEndian.Uint32(data[off : off+4]))
|
||||
sum += float64(v) / 2147483648.0
|
||||
case format == 1 && bits == 24:
|
||||
v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
|
||||
if v&0x800000 != 0 {
|
||||
v |= ^0xffffff
|
||||
}
|
||||
sum += float64(v) / 8388608.0
|
||||
case format == 3 && bits == 32:
|
||||
sum += float64(math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4])))
|
||||
case format == 1 && bits == 8:
|
||||
sum += (float64(data[off]) - 128.0) / 128.0
|
||||
}
|
||||
}
|
||||
mono[i] = float32(sum / float64(channels))
|
||||
}
|
||||
return mono
|
||||
}
|
||||
|
||||
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
|
||||
if fromRate <= 0 || toRate <= 0 || len(samples) == 0 {
|
||||
return samples
|
||||
}
|
||||
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
|
||||
if n <= 1 {
|
||||
return slicesCloneOne(samples)
|
||||
}
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
|
||||
idx := int(pos)
|
||||
frac := float32(pos - float64(idx))
|
||||
if idx+1 < len(samples) {
|
||||
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
|
||||
} else {
|
||||
out[i] = samples[idx]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func slicesCloneOne(samples []float32) []float32 {
|
||||
if len(samples) == 0 {
|
||||
return nil
|
||||
}
|
||||
return []float32{samples[0]}
|
||||
}
|
||||
|
||||
func computeParakeetMelSpectrogram(samples []float32, extractor *AudioFeatureExtractor, opts *AudioOptions) ([]float32, int, int, error) {
|
||||
if len(samples) == 0 {
|
||||
return nil, 0, 0, fmt.Errorf("audio too short to encode")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = defaultAudioOptions()
|
||||
}
|
||||
|
||||
melBins := opts.melBins
|
||||
freqBins := parakeetNFFT/2 + 1
|
||||
window, melFilters := extractor.windowAndFilters(melBins, freqBins, opts.sampleRate)
|
||||
if len(window) != parakeetWinLength {
|
||||
return nil, 0, 0, fmt.Errorf("invalid Parakeet window length: %d", len(window))
|
||||
}
|
||||
if len(melFilters) != melBins*freqBins {
|
||||
return nil, 0, 0, fmt.Errorf("invalid Parakeet mel filter shape: %d", len(melFilters))
|
||||
}
|
||||
|
||||
emphasized := make([]float32, len(samples))
|
||||
emphasized[0] = samples[0]
|
||||
for i := 1; i < len(samples); i++ {
|
||||
emphasized[i] = samples[i] - parakeetPreemphasis*samples[i-1]
|
||||
}
|
||||
|
||||
frames := len(samples)/parakeetHopLength + 1
|
||||
validFrames := max(1, len(samples)/parakeetHopLength)
|
||||
if validFrames > frames {
|
||||
validFrames = frames
|
||||
}
|
||||
|
||||
result := make([]float32, frames*melBins)
|
||||
fftInput := make([]complex128, parakeetNFFT)
|
||||
winOffset := (parakeetNFFT - parakeetWinLength) / 2
|
||||
centerPad := parakeetNFFT / 2
|
||||
|
||||
for frame := range frames {
|
||||
for i := range parakeetNFFT {
|
||||
fftInput[i] = 0
|
||||
}
|
||||
for i := range parakeetWinLength {
|
||||
src := frame*parakeetHopLength + i + winOffset - centerPad
|
||||
if src >= 0 && src < len(emphasized) {
|
||||
fftInput[i+winOffset] = complex(float64(emphasized[src])*float64(window[i]), 0)
|
||||
}
|
||||
}
|
||||
|
||||
fft(fftInput)
|
||||
|
||||
for mel := range melBins {
|
||||
var v float64
|
||||
filterOffset := mel * freqBins
|
||||
for freq := range freqBins {
|
||||
mag := cmplx.Abs(fftInput[freq])
|
||||
v += float64(melFilters[filterOffset+freq]) * mag * mag
|
||||
}
|
||||
result[frame*melBins+mel] = float32(math.Log(v + parakeetLogZeroGuardValue))
|
||||
}
|
||||
}
|
||||
|
||||
for mel := range melBins {
|
||||
var sum float64
|
||||
for frame := range validFrames {
|
||||
sum += float64(result[frame*melBins+mel])
|
||||
}
|
||||
mean := sum / float64(validFrames)
|
||||
|
||||
var variance float64
|
||||
for frame := range validFrames {
|
||||
d := float64(result[frame*melBins+mel]) - mean
|
||||
variance += d * d
|
||||
}
|
||||
denom := max(1, validFrames-1)
|
||||
std := math.Sqrt(variance / float64(denom))
|
||||
|
||||
for frame := range frames {
|
||||
idx := frame*melBins + mel
|
||||
if frame >= validFrames {
|
||||
result[idx] = 0
|
||||
continue
|
||||
}
|
||||
result[idx] = float32((float64(result[idx]) - mean) / (std + parakeetNormalizeEps))
|
||||
}
|
||||
}
|
||||
|
||||
return result, frames, validFrames, nil
|
||||
}
|
||||
|
||||
func defaultParakeetWindow() []float32 {
|
||||
window := make([]float32, parakeetWinLength)
|
||||
for i := range window {
|
||||
window[i] = float32(0.5 - 0.5*math.Cos(2*math.Pi*float64(i)/float64(parakeetWinLength-1)))
|
||||
}
|
||||
return window
|
||||
}
|
||||
|
||||
func buildSlaneyMelFilterBank(numFreqBins, numMels int, sampleRate int) []float32 {
|
||||
hzToMel := func(f float64) float64 {
|
||||
if f < 1000 {
|
||||
return 3 * f / 200
|
||||
}
|
||||
return 15 + math.Log(f/1000)*27/math.Log(6.4)
|
||||
}
|
||||
melToHz := func(m float64) float64 {
|
||||
if m < 15 {
|
||||
return 200 * m / 3
|
||||
}
|
||||
return 1000 * math.Exp(math.Log(6.4)*(m-15)/27)
|
||||
}
|
||||
|
||||
minMel := hzToMel(0)
|
||||
maxMel := hzToMel(float64(sampleRate) / 2)
|
||||
mels := make([]float64, numMels+2)
|
||||
freqs := make([]float64, numMels+2)
|
||||
for i := range mels {
|
||||
mels[i] = minMel + (maxMel-minMel)*float64(i)/float64(numMels+1)
|
||||
freqs[i] = melToHz(mels[i])
|
||||
}
|
||||
|
||||
fftFreqs := make([]float64, numFreqBins)
|
||||
for i := range fftFreqs {
|
||||
fftFreqs[i] = float64(i) * float64(sampleRate) / float64(parakeetNFFT)
|
||||
}
|
||||
|
||||
filters := make([]float32, numMels*numFreqBins)
|
||||
for mel := range numMels {
|
||||
left, center, right := freqs[mel], freqs[mel+1], freqs[mel+2]
|
||||
enorm := 2.0 / (right - left)
|
||||
for freq, fftFreq := range fftFreqs {
|
||||
var lower, upper float64
|
||||
if center > left {
|
||||
lower = (fftFreq - left) / (center - left)
|
||||
}
|
||||
if right > center {
|
||||
upper = (right - fftFreq) / (right - center)
|
||||
}
|
||||
v := math.Max(0, math.Min(lower, upper))
|
||||
filters[mel*numFreqBins+freq] = float32(v * enorm)
|
||||
}
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
func fft(x []complex128) {
|
||||
n := len(x)
|
||||
if n <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
j := 0
|
||||
for i := 1; i < n; i++ {
|
||||
bit := n >> 1
|
||||
for j&bit != 0 {
|
||||
j ^= bit
|
||||
bit >>= 1
|
||||
}
|
||||
j ^= bit
|
||||
if i < j {
|
||||
x[i], x[j] = x[j], x[i]
|
||||
}
|
||||
}
|
||||
|
||||
for size := 2; size <= n; size <<= 1 {
|
||||
halfSize := size / 2
|
||||
w := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
|
||||
for start := 0; start < n; start += size {
|
||||
wn := complex(1, 0)
|
||||
for k := range halfSize {
|
||||
t := wn * x[start+k+halfSize]
|
||||
x[start+k+halfSize] = x[start+k] - t
|
||||
x[start+k] = x[start+k] + t
|
||||
wn *= w
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
498
model/parsers/laguna.go
Normal file
498
model/parsers/laguna.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
lagunaThinkingOpenTag = "<think>"
|
||||
lagunaThinkingCloseTag = "</think>"
|
||||
lagunaToolCallOpenTag = "<tool_call>"
|
||||
lagunaToolCallCloseTag = "</tool_call>"
|
||||
lagunaUserOpenTag = "<user>"
|
||||
lagunaUserCloseTag = "</user>"
|
||||
)
|
||||
|
||||
type lagunaParserState int
|
||||
|
||||
const (
|
||||
lagunaParserStateThinking lagunaParserState = iota
|
||||
lagunaParserStateContent
|
||||
lagunaParserStateTool
|
||||
)
|
||||
|
||||
type LagunaParser struct {
|
||||
state lagunaParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
thinkingEnabled bool
|
||||
thinkingSuppressed bool
|
||||
allowLeadingThinkOpen bool
|
||||
}
|
||||
|
||||
func (p *LagunaParser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *LagunaParser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *LagunaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
p.buffer.Reset()
|
||||
p.thinkingEnabled = thinkValue == nil || thinkValue.Bool()
|
||||
p.thinkingSuppressed = thinkValue != nil && !thinkValue.Bool()
|
||||
p.state = lagunaParserStateContent
|
||||
p.allowLeadingThinkOpen = false
|
||||
return tools
|
||||
}
|
||||
|
||||
func (p *LagunaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
var contentSB, thinkingSB strings.Builder
|
||||
|
||||
for {
|
||||
progress := false
|
||||
switch p.state {
|
||||
case lagunaParserStateThinking:
|
||||
progress, thinking = p.consumeThinking(done)
|
||||
if p.thinkingEnabled {
|
||||
thinkingSB.WriteString(thinking)
|
||||
}
|
||||
case lagunaParserStateContent:
|
||||
var parsedCalls []api.ToolCall
|
||||
progress, content, parsedCalls, err = p.consumeContent(done)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
contentSB.WriteString(content)
|
||||
calls = append(calls, parsedCalls...)
|
||||
case lagunaParserStateTool:
|
||||
var call api.ToolCall
|
||||
progress, call, err = p.consumeTool(done)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
if progress {
|
||||
calls = append(calls, call)
|
||||
}
|
||||
}
|
||||
if !progress {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return contentSB.String(), thinkingSB.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *LagunaParser) consumeThinking(done bool) (bool, string) {
|
||||
acc := p.buffer.String()
|
||||
if p.allowLeadingThinkOpen {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, lagunaThinkingOpenTag) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingOpenTag), unicode.IsSpace))
|
||||
p.allowLeadingThinkOpen = false
|
||||
return true, ""
|
||||
}
|
||||
if strings.HasPrefix(lagunaThinkingOpenTag, trimmed) && !done {
|
||||
return false, ""
|
||||
}
|
||||
p.allowLeadingThinkOpen = false
|
||||
}
|
||||
|
||||
if idx := strings.Index(acc, lagunaThinkingCloseTag); idx != -1 {
|
||||
thinking := acc[:idx]
|
||||
after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingCloseTag):], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = lagunaParserStateContent
|
||||
return true, thinking
|
||||
}
|
||||
if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 {
|
||||
thinking := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
|
||||
after := acc[idx+len(lagunaToolCallOpenTag):]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = lagunaParserStateTool
|
||||
return true, thinking
|
||||
}
|
||||
if done {
|
||||
p.buffer.Reset()
|
||||
p.state = lagunaParserStateContent
|
||||
return acc != "", acc
|
||||
}
|
||||
|
||||
overlapLen := max(overlap(acc, lagunaThinkingCloseTag), overlap(acc, lagunaToolCallOpenTag))
|
||||
trailingLen := trailingWhitespaceLen(acc)
|
||||
keep := max(overlapLen, trailingLen)
|
||||
if keep > 0 && keep < len(acc) {
|
||||
emit := acc[:len(acc)-keep]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(acc[len(acc)-keep:])
|
||||
return emit != "", emit
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (p *LagunaParser) consumeContent(done bool) (bool, string, []api.ToolCall, error) {
|
||||
acc := p.buffer.String()
|
||||
if p.thinkingEnabled || p.thinkingSuppressed {
|
||||
if idx := strings.Index(acc, lagunaThinkingOpenTag); idx != -1 {
|
||||
content := acc[:idx]
|
||||
after := strings.TrimLeftFunc(acc[idx+len(lagunaThinkingOpenTag):], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = lagunaParserStateThinking
|
||||
p.allowLeadingThinkOpen = false
|
||||
return true, content, nil, nil
|
||||
}
|
||||
if !done {
|
||||
overlapLen := overlap(acc, lagunaThinkingOpenTag)
|
||||
if overlapLen > 0 && overlapLen < len(acc) {
|
||||
content := acc[:len(acc)-overlapLen]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(acc[len(acc)-overlapLen:])
|
||||
return content != "", content, nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if p.thinkingEnabled {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, lagunaThinkingCloseTag) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingCloseTag), unicode.IsSpace))
|
||||
return true, "", nil, nil
|
||||
}
|
||||
if strings.HasPrefix(lagunaThinkingCloseTag, trimmed) && !done {
|
||||
return false, "", nil, nil
|
||||
}
|
||||
}
|
||||
if p.thinkingSuppressed {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, lagunaThinkingCloseTag) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(strings.TrimLeftFunc(strings.TrimPrefix(trimmed, lagunaThinkingCloseTag), unicode.IsSpace))
|
||||
return true, "", nil, nil
|
||||
}
|
||||
if strings.HasPrefix(lagunaThinkingCloseTag, trimmed) && !done {
|
||||
return false, "", nil, nil
|
||||
}
|
||||
}
|
||||
if idx := strings.Index(acc, lagunaToolCallOpenTag); idx != -1 {
|
||||
content := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
|
||||
after := acc[idx+len(lagunaToolCallOpenTag):]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = lagunaParserStateTool
|
||||
return true, content, nil, nil
|
||||
}
|
||||
if idx := strings.Index(acc, lagunaUserOpenTag); idx != -1 && len(p.tools) > 0 {
|
||||
before := strings.TrimRightFunc(acc[:idx], unicode.IsSpace)
|
||||
afterOpen := acc[idx+len(lagunaUserOpenTag):]
|
||||
if closeIdx := strings.Index(afterOpen, lagunaUserCloseTag); closeIdx != -1 {
|
||||
raw := afterOpen[:closeIdx]
|
||||
if call, ok := p.parseToolAlias(raw); ok {
|
||||
after := strings.TrimLeftFunc(afterOpen[closeIdx+len(lagunaUserCloseTag):], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return true, before, []api.ToolCall{call}, nil
|
||||
}
|
||||
} else if !done {
|
||||
if idx > 0 {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(acc[idx:])
|
||||
return true, before, nil, nil
|
||||
}
|
||||
return false, "", nil, nil
|
||||
}
|
||||
}
|
||||
if len(p.tools) > 0 {
|
||||
if progress, content, call, ok, err := p.consumeStandaloneJSONTool(done); ok || err != nil {
|
||||
if err != nil {
|
||||
return false, "", nil, err
|
||||
}
|
||||
if progress {
|
||||
return true, content, []api.ToolCall{call}, nil
|
||||
}
|
||||
return false, "", nil, nil
|
||||
}
|
||||
}
|
||||
if done {
|
||||
p.buffer.Reset()
|
||||
return acc != "", acc, nil, nil
|
||||
}
|
||||
overlapLen := max(overlap(acc, lagunaToolCallOpenTag), overlap(acc, lagunaUserOpenTag))
|
||||
if p.thinkingEnabled || p.thinkingSuppressed {
|
||||
overlapLen = max(overlapLen, overlap(acc, lagunaThinkingOpenTag))
|
||||
}
|
||||
if p.thinkingSuppressed {
|
||||
overlapLen = max(overlapLen, overlap(acc, lagunaThinkingCloseTag))
|
||||
}
|
||||
trailingLen := trailingWhitespaceLen(acc)
|
||||
keep := max(overlapLen, trailingLen)
|
||||
if keep > 0 && keep < len(acc) {
|
||||
emit := acc[:len(acc)-keep]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(acc[len(acc)-keep:])
|
||||
return emit != "", emit, nil, nil
|
||||
}
|
||||
if keep == 0 && acc != "" {
|
||||
p.buffer.Reset()
|
||||
return true, acc, nil, nil
|
||||
}
|
||||
return false, "", nil, nil
|
||||
}
|
||||
|
||||
func (p *LagunaParser) consumeStandaloneJSONTool(done bool) (progress bool, content string, call api.ToolCall, ok bool, err error) {
|
||||
acc := p.buffer.String()
|
||||
jsonIdx := strings.Index(acc, "{")
|
||||
if jsonIdx == -1 {
|
||||
return false, "", api.ToolCall{}, false, nil
|
||||
}
|
||||
|
||||
before := strings.TrimRightFunc(acc[:jsonIdx], unicode.IsSpace)
|
||||
raw := strings.TrimLeftFunc(acc[jsonIdx:], unicode.IsSpace)
|
||||
if !lagunaLooksLikeJSONToolCall(raw, done) {
|
||||
return false, "", api.ToolCall{}, false, nil
|
||||
}
|
||||
|
||||
if !done && !json.Valid([]byte(strings.TrimSpace(raw))) {
|
||||
if before != "" {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(acc[jsonIdx:])
|
||||
return true, before, api.ToolCall{}, true, nil
|
||||
}
|
||||
return false, "", api.ToolCall{}, true, nil
|
||||
}
|
||||
|
||||
call, err = parseLagunaToolCall(raw, p.tools)
|
||||
if err != nil {
|
||||
return false, "", api.ToolCall{}, true, err
|
||||
}
|
||||
call.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
p.buffer.Reset()
|
||||
p.state = lagunaParserStateContent
|
||||
return true, before, call, true, nil
|
||||
}
|
||||
|
||||
func lagunaLooksLikeJSONToolCall(raw string, done bool) bool {
|
||||
trimmed := strings.TrimLeftFunc(raw, unicode.IsSpace)
|
||||
if !strings.HasPrefix(trimmed, "{") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, `"name"`) || strings.Contains(trimmed, `"arguments"`) {
|
||||
return true
|
||||
}
|
||||
if done {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(trimmed, `{"`) || strings.HasPrefix(trimmed, "{\n") || strings.HasPrefix(trimmed, "{\r\n")
|
||||
}
|
||||
|
||||
func (p *LagunaParser) parseToolAlias(raw string) (api.ToolCall, bool) {
|
||||
raw = cleanLagunaToolCallRaw(raw)
|
||||
name, ok := lagunaToolCallName(raw)
|
||||
if !ok {
|
||||
return api.ToolCall{}, false
|
||||
}
|
||||
if _, ok := lagunaResolveToolName(name, p.tools); !ok {
|
||||
return api.ToolCall{}, false
|
||||
}
|
||||
call, err := parseLagunaToolCall(raw, p.tools)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, false
|
||||
}
|
||||
call.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
return call, true
|
||||
}
|
||||
|
||||
func lagunaResolveToolName(name string, tools []api.Tool) (string, bool) {
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == name {
|
||||
return name, true
|
||||
}
|
||||
}
|
||||
|
||||
aliases := map[string]string{
|
||||
"read_file": "read",
|
||||
"write_file": "write",
|
||||
"edit_file": "edit",
|
||||
"web_fetch": "webfetch",
|
||||
}
|
||||
if alias, ok := aliases[name]; ok {
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == alias {
|
||||
return alias, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return name, false
|
||||
}
|
||||
|
||||
func cleanLagunaToolCallRaw(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
for strings.HasPrefix(raw, lagunaToolCallOpenTag) {
|
||||
raw = strings.TrimSpace(strings.TrimPrefix(raw, lagunaToolCallOpenTag))
|
||||
}
|
||||
if idx := strings.Index(raw, lagunaToolCallCloseTag); idx != -1 {
|
||||
raw = strings.TrimSpace(raw[:idx])
|
||||
}
|
||||
if idx := strings.Index(raw, lagunaToolCallOpenTag); idx != -1 {
|
||||
before := strings.TrimSpace(raw[:idx])
|
||||
if before != "" {
|
||||
return before
|
||||
}
|
||||
raw = strings.TrimSpace(raw[idx+len(lagunaToolCallOpenTag):])
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func lagunaToolCallName(raw string) (string, bool) {
|
||||
raw = cleanLagunaToolCallRaw(raw)
|
||||
if strings.HasPrefix(raw, "{") {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
return "", false
|
||||
}
|
||||
name := strings.TrimSpace(parsed.Name)
|
||||
return name, name != ""
|
||||
}
|
||||
|
||||
nameEnd := strings.Index(raw, "<arg_key>")
|
||||
if nameEnd < 0 {
|
||||
nameEnd = strings.Index(raw, "{")
|
||||
}
|
||||
if nameEnd < 0 {
|
||||
nameEnd = strings.IndexAny(raw, "\r\n")
|
||||
}
|
||||
if nameEnd < 0 {
|
||||
nameEnd = len(raw)
|
||||
}
|
||||
name := strings.TrimSpace(raw[:nameEnd])
|
||||
return name, name != ""
|
||||
}
|
||||
|
||||
func (p *LagunaParser) consumeTool(done bool) (bool, api.ToolCall, error) {
|
||||
acc := p.buffer.String()
|
||||
if idx := strings.Index(acc, lagunaToolCallCloseTag); idx != -1 {
|
||||
raw := acc[:idx]
|
||||
after := strings.TrimLeftFunc(acc[idx+len(lagunaToolCallCloseTag):], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = lagunaParserStateContent
|
||||
call, err := parseLagunaToolCall(raw, p.tools)
|
||||
if err != nil {
|
||||
return false, api.ToolCall{}, err
|
||||
}
|
||||
call.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
return true, call, nil
|
||||
}
|
||||
if done && strings.TrimSpace(acc) != "" {
|
||||
p.buffer.Reset()
|
||||
p.state = lagunaParserStateContent
|
||||
call, err := parseLagunaToolCall(acc, p.tools)
|
||||
if err != nil {
|
||||
return false, api.ToolCall{}, err
|
||||
}
|
||||
call.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
return true, call, nil
|
||||
}
|
||||
return false, api.ToolCall{}, nil
|
||||
}
|
||||
|
||||
var lagunaArgRE = regexp.MustCompile(`(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>`)
|
||||
|
||||
func parseLagunaToolCall(raw string, tools []api.Tool) (api.ToolCall, error) {
|
||||
raw = cleanLagunaToolCallRaw(raw)
|
||||
if strings.HasPrefix(raw, "{") {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments api.ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call: %w", err)
|
||||
}
|
||||
if parsed.Name == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty Laguna tool call name")
|
||||
}
|
||||
if name, ok := lagunaResolveToolName(parsed.Name, tools); ok {
|
||||
parsed.Name = name
|
||||
}
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: parsed.Arguments,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
nameEnd := strings.Index(raw, "<arg_key>")
|
||||
name := raw
|
||||
argsText := ""
|
||||
if nameEnd >= 0 {
|
||||
name = raw[:nameEnd]
|
||||
argsText = raw[nameEnd:]
|
||||
} else if jsonStart := strings.Index(raw, "{"); jsonStart >= 0 {
|
||||
name = raw[:jsonStart]
|
||||
argsText = raw[jsonStart:]
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if resolved, ok := lagunaResolveToolName(name, tools); ok {
|
||||
name = resolved
|
||||
}
|
||||
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == name {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
call := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
if strings.HasPrefix(strings.TrimSpace(argsText), "{") {
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(argsText)), &call.Function.Arguments); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse Laguna JSON tool call arguments: %w", err)
|
||||
}
|
||||
return call, nil
|
||||
}
|
||||
for _, match := range lagunaArgRE.FindAllStringSubmatch(argsText, -1) {
|
||||
key := strings.TrimSpace(match[1])
|
||||
value := match[2]
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
paramType = append(paramType, anyOfProp.Type...)
|
||||
}
|
||||
} else {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
}
|
||||
call.Function.Arguments.Set(key, parseValue(value, paramType))
|
||||
}
|
||||
return call, nil
|
||||
}
|
||||
484
model/parsers/laguna_test.go
Normal file
484
model/parsers/laguna_test.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func lagunaTestTools() []api.Tool {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
|
||||
return []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func TestLagunaParserToolCall(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
if parser == nil {
|
||||
t.Fatal("expected laguna parser")
|
||||
}
|
||||
if !parser.HasToolSupport() || !parser.HasThinkingSupport() {
|
||||
t.Fatal("laguna parser should advertise tools and thinking")
|
||||
}
|
||||
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
content, thinking, calls, err := parser.Add("<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>Paris</arg_value>\n<arg_key>days</arg_key>\n<arg_value>3</arg_value>\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q, want empty", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("location"); got != "Paris" {
|
||||
t.Fatalf("location=%v, want Paris", got)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("days"); got != 3 {
|
||||
t.Fatalf("days=%v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserJSONToolCall(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("location"); got != "Paris" {
|
||||
t.Fatalf("location=%v, want Paris", got)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("days"); got != float64(3) {
|
||||
t.Fatalf("days=%v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserStandaloneJSONToolCall(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\",\"days\":3}}", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserStandaloneJSONToolCallAfterLeadingContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me call the weather tool.\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Let me call the weather tool." || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserStreamingStandaloneJSONToolCall(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco,", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add(" CA\"}}", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 1 {
|
||||
t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco, CA" {
|
||||
t.Fatalf("location=%v, want San Francisco, CA", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserNameLineJSONToolCall(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call>get_weather\n{\"location\":\"San Francisco\"}</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco" {
|
||||
t.Fatalf("location=%v, want San Francisco", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserNormalizesCommonToolAliases(t *testing.T) {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "read",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(tools, nil, nil)
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call>\n{\"name\":\"read_file\",\"arguments\":{\"path\":\"./go.mod\"}}\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "read" {
|
||||
t.Fatalf("name=%q, want read", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("path"); got != "./go.mod" {
|
||||
t.Fatalf("path=%v, want ./go.mod", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserIgnoresDuplicatedNestedToolCall(t *testing.T) {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("name", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "skill",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(tools, nil, nil)
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call>skill\n{\"name\":\"git-diff-review\"}\n<tool_call>skill\n{\"name\":\"git-diff-review\"}</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "skill" {
|
||||
t.Fatalf("name=%q, want skill", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("name"); got != "git-diff-review" {
|
||||
t.Fatalf("name arg=%v, want git-diff-review", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingThenTool(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<think>Need current weather.</think>\n<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>SF</arg_value>\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("content=%q, want empty", content)
|
||||
}
|
||||
if thinking != "Need current weather." {
|
||||
t.Fatalf("thinking=%q, want reasoning", thinking)
|
||||
}
|
||||
if len(calls) != 1 || calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("unexpected calls: %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserUserTaggedToolAlias(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("<user>get_weather\n<arg_key>location</arg_key>\n<arg_value>San Francisco, CA</arg_value>\n</user>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q, want empty", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("name=%q, want get_weather", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("location"); got != "San Francisco, CA" {
|
||||
t.Fatalf("location=%v, want San Francisco, CA", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserUserTaggedToolAliasWithLeadingContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "read",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
parser.Init(tools, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("I'll read the file for you.\n<user>read\n<arg_key>path</arg_key>\n<arg_value>/Users/test/code/myproject/go.mod</arg_value>\n</user>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "I'll read the file for you." || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "read" {
|
||||
t.Fatalf("name=%q, want read", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" {
|
||||
t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserUserTaggedJSONToolCallWithLeadingContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("command", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "bash",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
parser.Init(tools, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("I'll run git diff for you.<user>\n{\"name\":\"bash\",\"arguments\":{\"command\":\"git diff main\"}}\n</user>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "I'll run git diff for you." || thinking != "" {
|
||||
t.Fatalf("content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "bash" {
|
||||
t.Fatalf("name=%q, want bash", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("command"); got != "git diff main" {
|
||||
t.Fatalf("command=%v, want git diff main", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserStreamingUserTaggedToolAliasAfterContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
tools := []api.Tool{{
|
||||
Function: api.ToolFunction{
|
||||
Name: "read",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}}
|
||||
parser.Init(tools, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("I'll read the file for you.<us", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "I'll read the file for you." || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("er>read\n<arg_key>path</arg_key>\n<arg_value>/Users/test/code/myproject/go.mod</arg_value>\n</user>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" {
|
||||
t.Fatalf("second chunk content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls=%d, want 1", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "read" {
|
||||
t.Fatalf("name=%q, want read", calls[0].Function.Name)
|
||||
}
|
||||
if got, _ := calls[0].Function.Arguments.Get("path"); got != "/Users/test/code/myproject/go.mod" {
|
||||
t.Fatalf("path=%v, want /Users/test/code/myproject/go.mod", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserUserTaggedNonToolContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("<user>hello</user>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "<user>hello</user>" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingDefaultsOn(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, nil)
|
||||
content, thinking, calls, err := parser.Add("<think>Need to reason.</think>\nDirect answer.", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Direct answer." || thinking != "Need to reason." || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingDefaultsOnWhenToolsPresent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, nil)
|
||||
content, thinking, calls, err := parser.Add("<think>Need to reason.</think>\n<tool_call>get_weather\n<arg_key>location</arg_key>\n<arg_value>Paris</arg_value>\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if thinking != "Need to reason." || len(calls) != 1 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("content=%q, want thinking block suppressed from content when default thinking is enabled", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingExplicitlyDisabled(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := parser.Add("<think>Hidden?</think>\nDirect answer.", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Direct answer." || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingExplicitlyDisabledDropsLeadingCloseTag(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := parser.Add("</think>\nTokyo\n", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Tokyo\n" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingEnabledDropsLeadingCloseTag(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
content, thinking, calls, err := parser.Add("</think>\nTokyo\n", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Tokyo\n" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingDefaultOnDropsLeadingCloseTag(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, nil)
|
||||
content, thinking, calls, err := parser.Add("</think>\nTokyo\n", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Tokyo\n" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserThinkingEnabledUntaggedAnswerIsContent(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
content, thinking, calls, err := parser.Add("Direct answer.", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "Direct answer." || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaParserSplitToolTag(t *testing.T) {
|
||||
parser := ParserForName("laguna")
|
||||
parser.Init(lagunaTestTools(), nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<think>Need lookup<tool_c", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "Need lookup" || len(calls) != 0 {
|
||||
t.Fatalf("first chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("all>get_weather\n<arg_key>location</arg_key>\n<arg_value>SF</arg_value>\n</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 1 {
|
||||
t.Fatalf("second chunk content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
}
|
||||
@@ -16,14 +16,17 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
nemotronThinkOpen = "<think>"
|
||||
nemotronThinkClose = "</think>"
|
||||
nemotronToolCallOpen = "<tool_call>"
|
||||
)
|
||||
|
||||
type Nemotron3NanoParser struct {
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
toolParser *Qwen3CoderParser
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
toolParser *Qwen3CoderParser
|
||||
maybeThinkingOpenAtBOL bool
|
||||
skipThinkingLeadingWS bool
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
|
||||
@@ -32,14 +35,18 @@ func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
|
||||
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.toolParser = &Qwen3CoderParser{}
|
||||
p.toolParser.Init(tools, nil, nil)
|
||||
p.buffer.Reset()
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
p.skipThinkingLeadingWS = false
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
thinkingEnabled := thinkValue == nil || thinkValue.Bool()
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
if !thinkingEnabled || (prefill && lastMessage.Content != "") {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingThinking
|
||||
p.maybeThinkingOpenAtBOL = true
|
||||
}
|
||||
|
||||
return tools
|
||||
@@ -61,6 +68,29 @@ func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking
|
||||
|
||||
// Nemotron3NanoCollectingThinking - buffer and look for end markers
|
||||
p.buffer.WriteString(s)
|
||||
if p.skipThinkingLeadingWS {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(trimmed)
|
||||
if trimmed == "" {
|
||||
return "", "", nil, nil
|
||||
}
|
||||
p.skipThinkingLeadingWS = false
|
||||
}
|
||||
if p.stripOpeningThinkTag() {
|
||||
return p.Add("", done)
|
||||
}
|
||||
if p.maybeThinkingOpenAtBOL {
|
||||
bufStr := p.buffer.String()
|
||||
trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||
if trimmed == "" || overlap(trimmed, nemotronThinkOpen) == len(trimmed) {
|
||||
if len(trimmed) != len(bufStr) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(trimmed)
|
||||
}
|
||||
return "", "", nil, nil
|
||||
}
|
||||
}
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
|
||||
@@ -124,3 +154,35 @@ func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
|
||||
p.buffer.Reset()
|
||||
return bufStr
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) stripOpeningThinkTag() bool {
|
||||
if !p.maybeThinkingOpenAtBOL {
|
||||
return false
|
||||
}
|
||||
|
||||
bufStr := p.buffer.String()
|
||||
trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||
if trimmed == "" {
|
||||
p.buffer.Reset()
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(trimmed, nemotronThinkOpen) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(strings.TrimLeftFunc(trimmed[len(nemotronThinkOpen):], unicode.IsSpace))
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
p.skipThinkingLeadingWS = true
|
||||
return true
|
||||
}
|
||||
|
||||
if overlap(trimmed, nemotronThinkOpen) == len(trimmed) {
|
||||
if len(trimmed) != len(bufStr) {
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(trimmed)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -82,6 +82,20 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking: "My thoughts...",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "leading open think tag is ignored",
|
||||
input: "<think>\nLet me think about this...</think>\nHere is my answer.",
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expectedThinking: "Let me think about this...",
|
||||
expectedContent: "Here is my answer.",
|
||||
},
|
||||
{
|
||||
name: "empty explicit think block is ignored",
|
||||
input: "<think></think>\nHere is my answer.",
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expectedThinking: "",
|
||||
expectedContent: "Here is my answer.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -191,6 +205,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "leading open think tag split across chunks",
|
||||
chunks: []string{"<th", "ink>", "\nThink first", "</think>", "\nDone."},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expectedThinking: "Think first",
|
||||
expectedContent: "Done.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -265,11 +286,11 @@ func TestNemotron3NanoParser_Init(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("starts in content state when nil thinkValue", func(t *testing.T) {
|
||||
t.Run("starts in thinking state when nil thinkValue", func(t *testing.T) {
|
||||
p := &Nemotron3NanoParser{}
|
||||
p.Init(nil, nil, nil)
|
||||
if p.state != Nemotron3NanoCollectingContent {
|
||||
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
|
||||
if p.state != Nemotron3NanoCollectingThinking {
|
||||
t.Errorf("expected state Nemotron3NanoCollectingThinking, got %v", p.state)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -281,6 +302,29 @@ func TestNemotron3NanoParser_Init(t *testing.T) {
|
||||
t.Errorf("expected state Nemotron3NanoCollectingContent, got %v", p.state)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reinit clears buffered state", func(t *testing.T) {
|
||||
p := &Nemotron3NanoParser{}
|
||||
p.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
if _, _, _, err := p.Add("thinking in progress", false); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
p.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := p.Add("content only", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "content only" {
|
||||
t.Fatalf("expected content after reinit, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking after reinit, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls after reinit, got %v", calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
|
||||
@@ -87,6 +87,8 @@ func ParserForName(name string) Parser {
|
||||
return &LFM2Parser{hasThinkingSupport: false}
|
||||
case "lfm2-thinking":
|
||||
return &LFM2Parser{hasThinkingSupport: true}
|
||||
case "laguna":
|
||||
return &LagunaParser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
111
model/renderers/laguna.go
Normal file
111
model/renderers/laguna.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
lagunaBOS = "〈|EOS|〉"
|
||||
lagunaThoughtOpen = "<think>"
|
||||
lagunaThoughtClose = "</think>"
|
||||
)
|
||||
|
||||
type LagunaRenderer struct{}
|
||||
|
||||
func (r *LagunaRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(lagunaBOS)
|
||||
|
||||
thinkingEnabled := think == nil || think.Bool()
|
||||
systemMessage := ""
|
||||
firstMessageIsSystem := len(messages) > 0 && messages[0].Role == "system"
|
||||
if firstMessageIsSystem {
|
||||
systemMessage = strings.TrimRight(messages[0].Content, "\n")
|
||||
}
|
||||
|
||||
sb.WriteString("<system>\n")
|
||||
if thinkingEnabled {
|
||||
sb.WriteString("You should use chain-of-thought reasoning. Put your reasoning inside <think> </think> tags before your response.")
|
||||
} else {
|
||||
sb.WriteString("You should respond directly without using chain-of-thought reasoning tags.")
|
||||
}
|
||||
if strings.TrimSpace(systemMessage) != "" {
|
||||
sb.WriteByte('\n')
|
||||
sb.WriteString(systemMessage)
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("\n\n### Tools\n\n")
|
||||
sb.WriteString("You may call functions to assist with the user query.\n")
|
||||
sb.WriteString("All available function signatures are listed below:\n")
|
||||
sb.WriteString("<available_tools>\n")
|
||||
for _, tool := range tools {
|
||||
if b, err := marshalWithSpaces(tool); err == nil {
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
sb.WriteString("</available_tools>\n\n")
|
||||
sb.WriteString("For each function call, return a json object with function name and arguments within '<tool_call>' and '</tool_call>' tags:\n")
|
||||
sb.WriteString("<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>")
|
||||
}
|
||||
sb.WriteString("\n</system>\n")
|
||||
|
||||
for i, message := range messages {
|
||||
if i == 0 && firstMessageIsSystem {
|
||||
continue
|
||||
}
|
||||
content := message.Content
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<user>\n")
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n</user>\n")
|
||||
case "assistant":
|
||||
lastMessage := i == len(messages)-1
|
||||
prefill := lastMessage && (content != "" || message.Thinking != "" || len(message.ToolCalls) > 0)
|
||||
sb.WriteString("<assistant>\n")
|
||||
if thinkingEnabled && message.Thinking != "" {
|
||||
sb.WriteString(lagunaThoughtOpen)
|
||||
sb.WriteString(message.Thinking)
|
||||
sb.WriteString(lagunaThoughtClose)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if strings.Trim(content, "\n") != "" {
|
||||
sb.WriteString(strings.Trim(content, "\n"))
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<tool_call>")
|
||||
sb.WriteString(toolCall.Function.Name)
|
||||
sb.WriteByte('\n')
|
||||
for name, value := range toolCall.Function.Arguments.All() {
|
||||
sb.WriteString("<arg_key>")
|
||||
sb.WriteString(name)
|
||||
sb.WriteString("</arg_key>\n")
|
||||
sb.WriteString("<arg_value>")
|
||||
sb.WriteString(formatToolCallArgument(value))
|
||||
sb.WriteString("</arg_value>\n")
|
||||
}
|
||||
sb.WriteString("</tool_call>\n")
|
||||
}
|
||||
if !prefill {
|
||||
sb.WriteString("</assistant>\n")
|
||||
}
|
||||
case "tool":
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
case "system":
|
||||
sb.WriteString("<system>\n")
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n</system>\n")
|
||||
}
|
||||
}
|
||||
|
||||
if len(messages) == 0 || messages[len(messages)-1].Role != "assistant" {
|
||||
sb.WriteString("<assistant>\n")
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
339
model/renderers/laguna_test.go
Normal file
339
model/renderers/laguna_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
lagunaDirectDirective = "You should respond directly without using chain-of-thought reasoning tags."
|
||||
lagunaThinkDirective = "You should use chain-of-thought reasoning. Put your reasoning inside <think> </think> tags before your response."
|
||||
)
|
||||
|
||||
func TestLagunaRendererReferenceFlowCoverage(t *testing.T) {
|
||||
weather := lagunaWeatherTool()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
think *api.ThinkValue
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "user_only_thinking_default_on",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\n</system>\n" +
|
||||
"<user>\nHello\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "user_only_thinking_enabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
think: &api.ThinkValue{Value: true},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\n</system>\n" +
|
||||
"<user>\nHello\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "user_only_thinking_disabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
think: &api.ThinkValue{Value: false},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaDirectDirective +
|
||||
"\n</system>\n" +
|
||||
"<user>\nHello\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "first_system_is_header",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Stay concise.\n\n"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\nStay concise." +
|
||||
"\n</system>\n" +
|
||||
"<user>\nHi\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "additional_system_message_renders_in_loop",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Primary."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "system", Content: "Secondary."},
|
||||
},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\nPrimary." +
|
||||
"\n</system>\n" +
|
||||
"<user>\nHi\n</user>\n" +
|
||||
"<system>\nSecondary.\n</system>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "tools_in_header",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Stay concise."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: weather,
|
||||
think: &api.ThinkValue{Value: true},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\nStay concise." +
|
||||
"\n\n### Tools\n\n" +
|
||||
"You may call functions to assist with the user query.\n" +
|
||||
"All available function signatures are listed below:\n" +
|
||||
"<available_tools>\n" +
|
||||
`{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "City"}}}}}` + "\n" +
|
||||
"</available_tools>\n\n" +
|
||||
"For each function call, return a json object with function name and arguments within '<tool_call>' and '</tool_call>' tags:\n" +
|
||||
"<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" +
|
||||
"\n</system>\n" +
|
||||
"<user>\nWeather?\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "tools_default_thinking_on_when_unspecified",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: weather,
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\n\n### Tools\n\n" +
|
||||
"You may call functions to assist with the user query.\n" +
|
||||
"All available function signatures are listed below:\n" +
|
||||
"<available_tools>\n" +
|
||||
`{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "City"}}}}}` + "\n" +
|
||||
"</available_tools>\n\n" +
|
||||
"For each function call, return a json object with function name and arguments within '<tool_call>' and '</tool_call>' tags:\n" +
|
||||
"<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" +
|
||||
"\n</system>\n" +
|
||||
"<user>\nWeather?\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant_history_with_thinking_content_tool_and_response",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Add these."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "\nCalling the tool.\n",
|
||||
Thinking: "Need addition.",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "a", Value: 2},
|
||||
{Key: "b", Value: 3},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "tool", Content: "5"},
|
||||
{Role: "user", Content: "Thanks"},
|
||||
},
|
||||
think: &api.ThinkValue{Value: true},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\n</system>\n" +
|
||||
"<user>\nAdd these.\n</user>\n" +
|
||||
"<assistant>\n" +
|
||||
"<think>Need addition.</think>\n" +
|
||||
"Calling the tool.\n" +
|
||||
"<tool_call>add\n" +
|
||||
"<arg_key>a</arg_key>\n<arg_value>2</arg_value>\n" +
|
||||
"<arg_key>b</arg_key>\n<arg_value>3</arg_value>\n" +
|
||||
"</tool_call>\n" +
|
||||
"</assistant>\n" +
|
||||
"<tool_response>\n5\n</tool_response>\n" +
|
||||
"<user>\nThanks\n</user>\n" +
|
||||
"<assistant>\n",
|
||||
},
|
||||
{
|
||||
name: "final_assistant_prefill_is_continued",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Complete this"},
|
||||
{Role: "assistant", Content: "Partial"},
|
||||
},
|
||||
want: "" +
|
||||
"〈|EOS|〉<system>\n" +
|
||||
lagunaThinkDirective +
|
||||
"\n</system>\n" +
|
||||
"<user>\nComplete this\n</user>\n" +
|
||||
"<assistant>\nPartial\n",
|
||||
},
|
||||
}
|
||||
|
||||
renderer := &LagunaRenderer{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Fatalf("renderer output mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLagunaRendererMatchesLocalJinjaControlFlow(t *testing.T) {
|
||||
if os.Getenv("VERIFY_LAGUNA_JINJA2") == "" {
|
||||
t.Skip("set VERIFY_LAGUNA_JINJA2=1 to compare against the local Laguna chat_template.jinja")
|
||||
}
|
||||
python := "/Users/daniel/.codex/worktrees/7038/ollama/.venv/bin/python3"
|
||||
if _, err := os.Stat(python); err != nil {
|
||||
t.Fatalf("VERIFY_LAGUNA_JINJA2 requires %s with jinja2 installed", python)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
think *api.ThinkValue
|
||||
}{
|
||||
{
|
||||
name: "user_only",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
},
|
||||
{
|
||||
name: "system_user",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Stay concise.\n"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "additional_system_and_tool_response",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Primary."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{Role: "assistant", Content: "Calling."},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "system", Content: "Secondary."},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking_enabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Think briefly."}},
|
||||
think: &api.ThinkValue{Value: true},
|
||||
},
|
||||
{
|
||||
name: "thinking_disabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Answer directly."}},
|
||||
think: &api.ThinkValue{Value: false},
|
||||
},
|
||||
}
|
||||
|
||||
renderer := &LagunaRenderer{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := renderer.Render(tt.messages, nil, tt.think)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, modelDir := range []string{
|
||||
"/Users/daniel/Models/poolside/laguna-xs-23-04-2026",
|
||||
} {
|
||||
want := renderLagunaChatTemplate(t, python, modelDir, tt.messages, tt.think)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("%s mismatch (-chat_template +renderer):\n%s", modelDir, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func renderLagunaChatTemplate(t *testing.T, python, modelDir string, messages []api.Message, think *api.ThinkValue) string {
|
||||
t.Helper()
|
||||
|
||||
type templateMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
templateMessages := make([]templateMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
templateMessages = append(templateMessages, templateMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
messagesJSON, err := json.Marshal(templateMessages)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal messages: %v", err)
|
||||
}
|
||||
|
||||
enableThinking := "True"
|
||||
if think != nil && !think.Bool() {
|
||||
enableThinking = "False"
|
||||
}
|
||||
|
||||
script := `
|
||||
import json
|
||||
import sys
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_dir = sys.argv[1]
|
||||
messages = json.loads(sys.argv[2])
|
||||
enable_thinking = sys.argv[3] == "True"
|
||||
tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
print(tok.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=enable_thinking,
|
||||
), end="")
|
||||
`
|
||||
cmd := exec.Command(python, "-c", script, modelDir, string(messagesJSON), enableThinking)
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("chat_template render failed: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
return stdout.String()
|
||||
}
|
||||
|
||||
func lagunaWeatherTool() []api.Tool {
|
||||
return []api.Tool{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsOrdered([]orderedProp{{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City",
|
||||
},
|
||||
}}),
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
@@ -3,6 +3,9 @@ package renderers
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -12,15 +15,15 @@ type Nemotron3NanoRenderer struct{}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
imageOffset := 0
|
||||
|
||||
// thinking is enabled if user requests it
|
||||
enableThinking := thinkValue != nil && thinkValue.Bool()
|
||||
enableThinking := r.resolveThinking(messages, thinkValue)
|
||||
|
||||
// Extract system message if present
|
||||
var systemMessage string
|
||||
var loopMessages []api.Message
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
systemMessage = messages[0].Content
|
||||
systemMessage = r.sanitizeSystemMessage(messages[0].Content)
|
||||
loopMessages = messages[1:]
|
||||
} else {
|
||||
loopMessages = messages
|
||||
@@ -34,6 +37,7 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n\n")
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
if systemMessage != "" {
|
||||
sb.WriteString(systemMessage)
|
||||
@@ -45,28 +49,30 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
|
||||
}
|
||||
sb.WriteString(r.renderTools(tools))
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
sb.WriteString("<|im_end|>\n\n")
|
||||
|
||||
for i, message := range loopMessages {
|
||||
switch message.Role {
|
||||
case "assistant":
|
||||
// Build content with thinking tags
|
||||
content := r.buildContent(message)
|
||||
shouldTruncate := i < lastUserIdx
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
sb.WriteString(r.formatContent(content, shouldTruncate, true))
|
||||
sb.WriteString(r.formatToolCallContent(content, shouldTruncate))
|
||||
r.writeToolCalls(&sb, message.ToolCalls)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
formatted := r.formatContent(content, shouldTruncate, false)
|
||||
sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n")
|
||||
formatted := r.formatAssistantContent(content, shouldTruncate)
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
sb.WriteString(formatted)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "user", "system":
|
||||
sb.WriteString("<|im_start|>" + message.Role + "\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString(r.renderMessageContent(message, imageOffset))
|
||||
imageOffset += len(message.Images)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "tool":
|
||||
@@ -90,6 +96,8 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Add generation prompt
|
||||
if enableThinking {
|
||||
sb.WriteString("<|im_start|>assistant\n<think>\n")
|
||||
@@ -119,7 +127,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||
sb.WriteString("\n<name>" + paramName + "</name>")
|
||||
|
||||
if len(paramFields.Type) > 0 {
|
||||
sb.WriteString("\n<type>" + strings.Join(paramFields.Type, ", ") + "</type>")
|
||||
sb.WriteString("\n<type>" + r.formatPropertyType(paramFields.Type) + "</type>")
|
||||
}
|
||||
|
||||
if paramFields.Description != "" {
|
||||
@@ -127,17 +135,17 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||
}
|
||||
|
||||
if len(paramFields.Enum) > 0 {
|
||||
enumJSON, _ := json.Marshal(paramFields.Enum)
|
||||
sb.WriteString("\n<enum>" + string(enumJSON) + "</enum>")
|
||||
sb.WriteString("\n<enum>" + r.pythonJSON(paramFields.Enum) + "</enum>")
|
||||
}
|
||||
|
||||
r.renderToolPropertyExtraKeys(&sb, paramFields)
|
||||
sb.WriteString("\n</parameter>")
|
||||
}
|
||||
}
|
||||
|
||||
r.renderToolParameterExtraKeys(&sb, fn.Parameters)
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
reqJSON, _ := json.Marshal(fn.Parameters.Required)
|
||||
sb.WriteString("\n<required>" + string(reqJSON) + "</required>")
|
||||
sb.WriteString("\n<required>" + r.pythonJSON(fn.Parameters.Required) + "</required>")
|
||||
}
|
||||
|
||||
sb.WriteString("\n</parameters>")
|
||||
@@ -159,27 +167,38 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) buildContent(message api.Message) string {
|
||||
// The parser always extracts thinking into the Thinking field,
|
||||
// so Content will never have <think> tags embedded
|
||||
content := nemotron3NanoRenderContent(message.Content)
|
||||
if message.Thinking != "" {
|
||||
return "<think>\n" + message.Thinking + "\n</think>\n" + message.Content
|
||||
return "<think>\n" + message.Thinking + "\n</think>\n" + content
|
||||
}
|
||||
return "<think></think>" + message.Content
|
||||
if !strings.Contains(content, "<think>") && !strings.Contains(content, "</think>") {
|
||||
return "<think></think>" + content
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, addNewline bool) string {
|
||||
if content == "" {
|
||||
func (r *Nemotron3NanoRenderer) formatAssistantContent(content string, truncate bool) string {
|
||||
if !truncate {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
c := content
|
||||
if strings.Contains(c, "<think>") && strings.Contains(c, "</think>") {
|
||||
parts := strings.Split(c, "</think>")
|
||||
c = "<think></think>" + parts[len(parts)-1]
|
||||
}
|
||||
return strings.TrimSpace(c)
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) formatToolCallContent(content string, truncate bool) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return "<think></think>"
|
||||
}
|
||||
|
||||
if !truncate {
|
||||
if addNewline {
|
||||
return strings.TrimSpace(content) + "\n"
|
||||
}
|
||||
return strings.TrimSpace(content)
|
||||
return strings.TrimSpace(content) + "\n"
|
||||
}
|
||||
|
||||
// Truncate thinking - keep only content after </think>
|
||||
c := content
|
||||
if strings.Contains(c, "</think>") {
|
||||
parts := strings.Split(c, "</think>")
|
||||
@@ -190,13 +209,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
|
||||
}
|
||||
c = "<think></think>" + strings.TrimSpace(c)
|
||||
|
||||
if addNewline && len(c) > len("<think></think>") {
|
||||
return c + "\n"
|
||||
}
|
||||
if c == "<think></think>" {
|
||||
return c
|
||||
}
|
||||
return strings.TrimSpace(c)
|
||||
return strings.TrimSpace(c) + "\n"
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
||||
@@ -212,9 +225,225 @@ func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []
|
||||
func (r *Nemotron3NanoRenderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case map[string]any, []any:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
return string(jsonBytes)
|
||||
return r.pythonJSON(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) renderMessageContent(message api.Message, imageOffset int) string {
|
||||
content := nemotron3NanoRenderContent(message.Content)
|
||||
if len(message.Images) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
if strings.Contains(content, "[img-") {
|
||||
return content
|
||||
}
|
||||
|
||||
if strings.Contains(content, "[img]") {
|
||||
for i := range message.Images {
|
||||
content = strings.Replace(content, "[img]", fmt.Sprintf("[img-%d]", imageOffset+i), 1)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for i := range message.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
||||
}
|
||||
sb.WriteString(content)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func nemotron3NanoRenderContent(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, item := range v {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
bts, _ := json.Marshal(item)
|
||||
sb.Write(bts)
|
||||
continue
|
||||
}
|
||||
|
||||
switch obj["type"] {
|
||||
case "image":
|
||||
sb.WriteString("<image>")
|
||||
case "text":
|
||||
if text, ok := obj["text"].(string); ok {
|
||||
sb.WriteString(text)
|
||||
}
|
||||
default:
|
||||
bts, _ := json.Marshal(item)
|
||||
sb.Write(bts)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
default:
|
||||
bts, _ := json.Marshal(v)
|
||||
return string(bts)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) resolveThinking(messages []api.Message, thinkValue *api.ThinkValue) bool {
|
||||
enableThinking := thinkValue == nil || thinkValue.Bool()
|
||||
for _, message := range messages {
|
||||
if message.Role != "user" && message.Role != "system" {
|
||||
continue
|
||||
}
|
||||
content := message.Content
|
||||
if strings.Contains(strings.ReplaceAll(content, "</think>", ""), "/think") {
|
||||
enableThinking = true
|
||||
} else if strings.Contains(content, "/no_think") {
|
||||
enableThinking = false
|
||||
}
|
||||
}
|
||||
return enableThinking
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) sanitizeSystemMessage(content string) string {
|
||||
system := nemotron3NanoRenderContent(content)
|
||||
system = strings.ReplaceAll(system, "</think>", "<_end_think>")
|
||||
system = strings.ReplaceAll(system, "/think", "")
|
||||
system = strings.ReplaceAll(system, "/no_think", "")
|
||||
system = strings.ReplaceAll(system, "<_end_think>", "</think>")
|
||||
return system
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) formatPropertyType(propertyType api.PropertyType) string {
|
||||
if len(propertyType) == 1 {
|
||||
return propertyType[0]
|
||||
}
|
||||
quoted := make([]string, 0, len(propertyType))
|
||||
for _, v := range propertyType {
|
||||
quoted = append(quoted, "'"+v+"'")
|
||||
}
|
||||
return "[" + strings.Join(quoted, ", ") + "]"
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) renderToolPropertyExtraKeys(sb *strings.Builder, prop api.ToolProperty) {
|
||||
if len(prop.AnyOf) > 0 {
|
||||
sb.WriteString("\n<anyOf>" + r.pythonJSON(prop.AnyOf) + "</anyOf>")
|
||||
}
|
||||
if prop.Items != nil {
|
||||
sb.WriteString("\n<items>" + r.pythonJSON(prop.Items) + "</items>")
|
||||
}
|
||||
if prop.Properties != nil {
|
||||
sb.WriteString("\n<properties>" + r.pythonJSON(prop.Properties) + "</properties>")
|
||||
}
|
||||
if len(prop.Required) > 0 {
|
||||
sb.WriteString("\n<required>" + r.pythonJSON(prop.Required) + "</required>")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) renderToolParameterExtraKeys(sb *strings.Builder, params api.ToolFunctionParameters) {
|
||||
if params.Defs != nil {
|
||||
sb.WriteString("\n<$defs>" + r.pythonJSON(params.Defs) + "</$defs>")
|
||||
}
|
||||
if params.Items != nil {
|
||||
sb.WriteString("\n<items>" + r.pythonJSON(params.Items) + "</items>")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Nemotron3NanoRenderer) pythonJSON(v any) string {
|
||||
switch value := v.(type) {
|
||||
case nil:
|
||||
return "null"
|
||||
case string:
|
||||
return strconv.Quote(value)
|
||||
case bool:
|
||||
if value {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case int, int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", reflect.ValueOf(value).Int())
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", reflect.ValueOf(value).Uint())
|
||||
case float32, float64:
|
||||
b, _ := json.Marshal(value)
|
||||
return string(b)
|
||||
case api.PropertyType:
|
||||
return r.pythonJSON([]string(value))
|
||||
case []string:
|
||||
parts := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
parts = append(parts, r.pythonJSON(item))
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
case []any:
|
||||
parts := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
parts = append(parts, r.pythonJSON(item))
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
case []api.ToolProperty:
|
||||
parts := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
parts = append(parts, r.pythonJSON(item))
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
case map[string]any:
|
||||
keys := make([]string, 0, len(value))
|
||||
for key := range value {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
parts := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
parts = append(parts, strconv.Quote(key)+": "+r.pythonJSON(value[key]))
|
||||
}
|
||||
return "{" + strings.Join(parts, ", ") + "}"
|
||||
case *api.ToolPropertiesMap:
|
||||
if value == nil {
|
||||
return "null"
|
||||
}
|
||||
parts := make([]string, 0, value.Len())
|
||||
for key, prop := range value.All() {
|
||||
parts = append(parts, strconv.Quote(key)+": "+r.pythonJSON(prop))
|
||||
}
|
||||
return "{" + strings.Join(parts, ", ") + "}"
|
||||
case api.ToolProperty:
|
||||
parts := make([]string, 0, 6)
|
||||
if len(value.AnyOf) > 0 {
|
||||
parts = append(parts, `"anyOf": `+r.pythonJSON(value.AnyOf))
|
||||
}
|
||||
if len(value.Type) > 0 {
|
||||
if len(value.Type) == 1 {
|
||||
parts = append(parts, `"type": `+r.pythonJSON(value.Type[0]))
|
||||
} else {
|
||||
parts = append(parts, `"type": `+r.pythonJSON([]string(value.Type)))
|
||||
}
|
||||
}
|
||||
if value.Items != nil {
|
||||
parts = append(parts, `"items": `+r.pythonJSON(value.Items))
|
||||
}
|
||||
if value.Description != "" {
|
||||
parts = append(parts, `"description": `+r.pythonJSON(value.Description))
|
||||
}
|
||||
if len(value.Enum) > 0 {
|
||||
parts = append(parts, `"enum": `+r.pythonJSON(value.Enum))
|
||||
}
|
||||
if value.Properties != nil {
|
||||
parts = append(parts, `"properties": `+r.pythonJSON(value.Properties))
|
||||
}
|
||||
if len(value.Required) > 0 {
|
||||
parts = append(parts, `"required": `+r.pythonJSON(value.Required))
|
||||
}
|
||||
return "{" + strings.Join(parts, ", ") + "}"
|
||||
default:
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "null"
|
||||
}
|
||||
var generic any
|
||||
if err := json.Unmarshal(b, &generic); err != nil {
|
||||
return string(b)
|
||||
}
|
||||
return r.pythonJSON(generic)
|
||||
}
|
||||
}
|
||||
|
||||
614
model/renderers/nemotron3nano_reference_test.go
Normal file
614
model/renderers/nemotron3nano_reference_test.go
Normal file
@@ -0,0 +1,614 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const nemotron3NanoTemplate = "testdata/nemotron3nano_chat_template.jinja2"
|
||||
|
||||
func TestNemotron3NanoRendererMatchesReference(t *testing.T) {
|
||||
toolText := `<|im_start|>system
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>search_docs</name>
|
||||
<description>Search docs</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>query</name>
|
||||
<type>string</type>
|
||||
<description>Search query</description>
|
||||
<enum>["api", "cli"]</enum>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>mode</name>
|
||||
<type>['string', 'null']</type>
|
||||
<description>Mode</description>
|
||||
<anyOf>[{"type": "string"}, {"type": "number"}]</anyOf>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>payload</name>
|
||||
<type>object</type>
|
||||
<description>Payload</description>
|
||||
<properties>{"enabled": {"type": "boolean"}}</properties>
|
||||
<required>["enabled"]</required>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>tags</name>
|
||||
<type>array</type>
|
||||
<description>Tags</description>
|
||||
<items>{"type": "string"}</items>
|
||||
</parameter>
|
||||
<$defs>{"shared": {"type": "string"}}</$defs>
|
||||
<required>["query"]</required>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
`
|
||||
toolTextWithSystem := `<|im_start|>system
|
||||
Follow policy.
|
||||
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>search_docs</name>
|
||||
<description>Search docs</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>query</name>
|
||||
<type>string</type>
|
||||
<description>Search query</description>
|
||||
<enum>["api", "cli"]</enum>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>mode</name>
|
||||
<type>['string', 'null']</type>
|
||||
<description>Mode</description>
|
||||
<anyOf>[{"type": "string"}, {"type": "number"}]</anyOf>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>payload</name>
|
||||
<type>object</type>
|
||||
<description>Payload</description>
|
||||
<properties>{"enabled": {"type": "boolean"}}</properties>
|
||||
<required>["enabled"]</required>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>tags</name>
|
||||
<type>array</type>
|
||||
<description>Tags</description>
|
||||
<items>{"type": "string"}</items>
|
||||
</parameter>
|
||||
<$defs>{"shared": {"type": "string"}}</$defs>
|
||||
<required>["query"]</required>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
think *api.ThinkValue
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no system default thinking on",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "no system explicit thinking off",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
think: thinkFalse(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
|
||||
},
|
||||
{
|
||||
name: "literal endthink does not enable thinking",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "literal </think> only"},
|
||||
},
|
||||
think: thinkFalse(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nliteral </think> only<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
|
||||
},
|
||||
{
|
||||
name: "user no think toggle",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello /no_think"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHello /no_think<|im_end|>\n\n<|im_start|>assistant\n<think></think>",
|
||||
},
|
||||
{
|
||||
name: "system think toggle overrides false",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Policy /think"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
think: thinkFalse(),
|
||||
expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "later toggle wins",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Policy /no_think"},
|
||||
{Role: "user", Content: "Actually /think"},
|
||||
},
|
||||
think: thinkFalse(),
|
||||
expected: "\n\n\n<|im_start|>system\nPolicy <|im_end|>\n\n<|im_start|>user\nActually /think<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "system sanitizes toggles but preserves closing tag",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "A /think B /no_think C </think>"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
think: thinkFalse(),
|
||||
expected: "\n\n\n<|im_start|>system\nA B C </think><|im_end|>\n\n<|im_start|>user\nHello<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant plain content adds empty think block",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello there"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Hello there<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant reasoning content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Answer", Thinking: "Need to think"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>\nNeed to think\n</think>\nAnswer<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant preserves existing think tags",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "<think>kept</think>Answer"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>kept</think>Answer<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "tools without system",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Use a tool"},
|
||||
},
|
||||
tools: nemotron3NanoReferenceTools(),
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n" + toolText + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "system with tools",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Follow policy."},
|
||||
{Role: "user", Content: "Use a tool"},
|
||||
},
|
||||
tools: nemotron3NanoReferenceTools(),
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n" + toolTextWithSystem + "\n<|im_start|>user\nUse a tool<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant tool call with content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking now.",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>Checking now.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant tool call with structured arguments",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Create data"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "payload", Value: map[string]any{"count": 42, "nested": map[string]any{"value": "ok"}}},
|
||||
{Key: "tags", Value: []any{"a", "b"}},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nCreate data<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=create>\n<parameter=payload>\n{\"count\": 42, \"nested\": {\"value\": \"ok\"}}\n</parameter>\n<parameter=tags>\n[\"a\", \"b\"]\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant tool call truncated with reasoning",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking now.",
|
||||
Thinking: "Need weather",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "user", Content: "And tomorrow?"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>Checking now.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant tool call truncated open think only",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>draft",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "user", Content: "And tomorrow?"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\nAnd tomorrow?<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant tool call empty content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nWeather?<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant truncated with think pair",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "<think>hidden</think>Visible"},
|
||||
{Role: "user", Content: "Next"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Visible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant truncated reasoning content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Thinking: "hidden", Content: "Visible"},
|
||||
{Role: "user", Content: "Next"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>\nVisible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant truncated plain content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Visible"},
|
||||
{Role: "user", Content: "Next"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think>Visible<|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant truncated empty content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "user", Content: "Next"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think></think><|im_end|>\n<|im_start|>user\nNext<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "consecutive tool messages grouped",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Do work"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "step",
|
||||
Arguments: testArgs(map[string]any{"value": 1}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "tool", Content: "one"},
|
||||
{Role: "tool", Content: "two"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n<think></think>\n<tool_call>\n<function=step>\n<parameter=value>\n1\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n<|im_start|>user\n<tool_response>\none\n</tool_response>\n<tool_response>\ntwo\n</tool_response>\n<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "fallback role",
|
||||
messages: []api.Message{
|
||||
{Role: "developer", Content: "Custom role content"},
|
||||
},
|
||||
think: thinkTrue(),
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>developer\nCustom role content<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
}
|
||||
|
||||
verifyJinja2 := os.Getenv("VERIFY_JINJA2") != ""
|
||||
if verifyJinja2 {
|
||||
if _, err := os.Stat(filepath.Join(nemotron3NanoRepoRoot(t), ".venv", "bin", "python3")); err != nil {
|
||||
t.Fatal("VERIFY_JINJA2=1 requires .venv/bin/python3")
|
||||
}
|
||||
}
|
||||
|
||||
renderer := &Nemotron3NanoRenderer{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||
t.Fatalf("renderer mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if verifyJinja2 {
|
||||
jinja2Output := renderNemotron3NanoWithJinja2(t, tt.messages, tt.tools, tt.think)
|
||||
if diff := cmp.Diff(tt.expected, jinja2Output); diff != "" {
|
||||
t.Fatalf("reference template mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func nemotron3NanoReferenceTools() []api.Tool {
|
||||
return []api.Tool{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search_docs",
|
||||
Description: "Search docs",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: map[string]any{"shared": map[string]any{"type": "string"}},
|
||||
Required: []string{"query"},
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "query",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Search query",
|
||||
Enum: []any{"api", "cli"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "mode",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string", "null"},
|
||||
Description: "Mode",
|
||||
AnyOf: []api.ToolProperty{
|
||||
{Type: api.PropertyType{"string"}},
|
||||
{Type: api.PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "payload",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"object"},
|
||||
Description: "Payload",
|
||||
Properties: testPropsOrdered([]orderedProp{{Key: "enabled", Value: api.ToolProperty{Type: api.PropertyType{"boolean"}}}}),
|
||||
Required: []string{"enabled"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "tags",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"array"},
|
||||
Description: "Tags",
|
||||
Items: map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func renderNemotron3NanoWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||
t.Helper()
|
||||
|
||||
type jinja2ToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments any `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
type jinja2Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []jinja2ToolCall `json:"tool_calls,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
var jMsgs []jinja2Message
|
||||
for _, m := range messages {
|
||||
jm := jinja2Message{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.Thinking,
|
||||
Name: m.ToolName,
|
||||
ToolCallID: m.ToolCallID,
|
||||
}
|
||||
for _, tc := range m.ToolCalls {
|
||||
jtc := jinja2ToolCall{ID: tc.ID}
|
||||
jtc.Function.Name = tc.Function.Name
|
||||
var args map[string]any
|
||||
raw, _ := tc.Function.Arguments.MarshalJSON()
|
||||
if err := json.Unmarshal(raw, &args); err != nil {
|
||||
t.Fatalf("failed to unmarshal tool args: %v", err)
|
||||
}
|
||||
jtc.Function.Arguments = args
|
||||
jm.ToolCalls = append(jm.ToolCalls, jtc)
|
||||
}
|
||||
jMsgs = append(jMsgs, jm)
|
||||
}
|
||||
|
||||
msgsJSON, err := json.Marshal(jMsgs)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal messages: %v", err)
|
||||
}
|
||||
|
||||
toolsJSON := "None"
|
||||
if len(tools) > 0 {
|
||||
b, err := json.Marshal(tools)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal tools: %v", err)
|
||||
}
|
||||
toolsJSON = string(b)
|
||||
}
|
||||
|
||||
thinking := "unset"
|
||||
if think != nil {
|
||||
if think.Bool() {
|
||||
thinking = "true"
|
||||
} else {
|
||||
thinking = "false"
|
||||
}
|
||||
}
|
||||
|
||||
repoRoot := nemotron3NanoRepoRoot(t)
|
||||
templatePath := filepath.Join(repoRoot, "model", "renderers", nemotron3NanoTemplate)
|
||||
pythonPath := filepath.Join(repoRoot, ".venv", "bin", "python3")
|
||||
script := `
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from transformers.utils.chat_template_utils import _compile_jinja_template
|
||||
|
||||
template_path, messages_json, tools_json, thinking = sys.argv[1:5]
|
||||
tmpl = _compile_jinja_template(Path(template_path).read_text())
|
||||
kwargs = {
|
||||
"messages": json.loads(messages_json),
|
||||
"add_generation_prompt": True,
|
||||
}
|
||||
if tools_json != "None":
|
||||
kwargs["tools"] = json.loads(tools_json)
|
||||
if thinking == "true":
|
||||
kwargs["enable_thinking"] = True
|
||||
elif thinking == "false":
|
||||
kwargs["enable_thinking"] = False
|
||||
print(tmpl.render(**kwargs), end="")
|
||||
`
|
||||
|
||||
cmd := exec.Command(pythonPath, "-c", script, templatePath, string(msgsJSON), toolsJSON, thinking)
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("python render failed: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
return stdout.String()
|
||||
}
|
||||
|
||||
func nemotron3NanoRepoRoot(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("failed to locate test file")
|
||||
}
|
||||
return filepath.Dir(filepath.Dir(filepath.Dir(filename)))
|
||||
}
|
||||
@@ -8,561 +8,46 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestNemotron3NanoRenderer(t *testing.T) {
|
||||
func TestNemotron3NanoRenderer_Images(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
name string
|
||||
msgs []api.Message
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic user message - thinking mode",
|
||||
name: "single image inserts placeholder",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
{Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "basic user message - no thinking",
|
||||
name: "generic image placeholder is rewritten",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
{Role: "user", Content: "[img]Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
|
||||
},
|
||||
thinkValue: nil,
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>",
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "with system message",
|
||||
name: "image offsets increment across messages",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
{Role: "user", Content: "Describe the first image.", Images: []api.ImageData{api.ImageData("img1")}},
|
||||
{Role: "assistant", Content: "It shows something."},
|
||||
{Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello! How can I help?"},
|
||||
{Role: "user", Content: "Tell me a joke"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHi<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>Hello! How can I help?<|im_end|>\n" +
|
||||
"<|im_start|>user\nTell me a joke<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_weather</name>\n" +
|
||||
"<description>Get the current weather</description>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
|
||||
"<required>[\"city\"]</required>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "tool call with response",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 72F"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_weather</name>\n" +
|
||||
"<description>Get the current weather</description>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>city</name>\n<type>string</type>\n<description>The city name</description>\n</parameter>\n" +
|
||||
"<required>[\"city\"]</required>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nSunny, 72F\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with content and tool call",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check that for you.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_weather</name>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nWhat's the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>Let me check that for you.\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "thinking in history is truncated",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHi<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>Hello!<|im_end|>\n" +
|
||||
"<|im_start|>user\nHow are you?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Weather in Paris and London?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "tool", Content: "Rainy"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_weather</name>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nSunny\n</tool_response>\n<tool_response>\nRainy\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled when user doesn't request it",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: nil,
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>",
|
||||
},
|
||||
{
|
||||
name: "complex message history with thinking, tools, tool calls, tool results and content",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
||||
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
|
||||
}},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
||||
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
||||
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
|
||||
}},
|
||||
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
||||
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "calculate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_weather</name>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>city</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n" +
|
||||
"<function>\n<name>calculate</name>\n" +
|
||||
"<parameters>\n" +
|
||||
"<parameter>\n<name>expression</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>\nI need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.\n</think>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nLondon\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nSunny, 22°C\n</tool_response>\n<tool_response>\nRainy, 15°C\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>\nNow I have the weather data. Let me calculate 2+2.\n</think>\n" +
|
||||
"<tool_call>\n<function=calculate>\n<parameter=expression>\n2+2\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\n4\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>\nPerfect! I have all the information needed to provide a complete answer.\n</think>\n" +
|
||||
"Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "empty messages list",
|
||||
msgs: []api.Message{},
|
||||
thinkValue: nil,
|
||||
expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n<think></think>",
|
||||
},
|
||||
{
|
||||
name: "tool result with JSON content",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get user info"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_user",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>get_user</name>\n<parameters>\n" +
|
||||
"<parameter>\n<name>id</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nGet user info<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>\n" +
|
||||
"<tool_call>\n<function=get_user>\n<parameter=id>\n123\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "assistant message with only thinking no content",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Think about this"},
|
||||
{Role: "assistant", Thinking: "Deep thoughts here...", Content: ""},
|
||||
{Role: "user", Content: "What did you think?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nThink about this<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think><|im_end|>\n" +
|
||||
"<|im_start|>user\nWhat did you think?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "tool call with complex nested argument",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Create data"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"data": map[string]any{"nested": "value", "count": 42},
|
||||
}),
|
||||
}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Created"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "create",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>create</name>\n<parameters>\n" +
|
||||
"<parameter>\n<name>data</name>\n<type>object</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nCreate data<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>\n" +
|
||||
"<tool_call>\n<function=create>\n<parameter=data>\n{\"count\":42,\"nested\":\"value\"}\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nCreated\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "content explaining the format itself",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "How do I format a tool call?"},
|
||||
{Role: "assistant", Content: "To call a tool, use <tool_call> tags with <function=name> inside."},
|
||||
{Role: "user", Content: "Thanks!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n<|im_end|>\n" +
|
||||
"<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>To call a tool, use <tool_call> tags with <function=name> inside.<|im_end|>\n" +
|
||||
"<|im_start|>user\nThanks!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "unicode in content and tool args",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Translate 你好"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Hello"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "translate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"text": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"# Tools\n\nYou have access to the following functions:\n\n<tools>\n" +
|
||||
"<function>\n<name>translate</name>\n<parameters>\n" +
|
||||
"<parameter>\n<name>text</name>\n<type>string</type>\n</parameter>\n" +
|
||||
"</parameters>\n</function>\n</tools>\n\n" +
|
||||
"If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" +
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" +
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" +
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" +
|
||||
"- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" +
|
||||
"- Required parameters MUST be specified\n" +
|
||||
"- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" +
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" +
|
||||
"</IMPORTANT><|im_end|>\n" +
|
||||
"<|im_start|>user\nTranslate 你好<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think></think>\n" +
|
||||
"<tool_call>\n<function=translate>\n<parameter=text>\n你好\n</parameter>\n</function>\n</tool_call>\n<|im_end|>\n" +
|
||||
"<|im_start|>user\n<tool_response>\nHello\n</tool_response>\n<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n<think>\n",
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2]Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &Nemotron3NanoRenderer{}
|
||||
rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue)
|
||||
rendered, err := renderer.Render(tt.msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
|
||||
t.Fatalf("mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -95,6 +95,8 @@ func rendererForName(name string) Renderer {
|
||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||
case "lfm2-thinking":
|
||||
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
|
||||
case "laguna":
|
||||
return &LagunaRenderer{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
222
model/renderers/testdata/nemotron3nano_chat_template.jinja2
vendored
Normal file
222
model/renderers/testdata/nemotron3nano_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,222 @@
|
||||
{% macro render_extra_keys(json_dict, handled_keys) %}
|
||||
{%- if json_dict is mapping %}
|
||||
{%- for json_key in json_dict if json_key not in handled_keys %}
|
||||
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
||||
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
||||
{%- else %}
|
||||
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% endmacro %}
|
||||
{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
|
||||
{%- set reasoning_budget = reasoning_budget if reasoning_budget is defined else None %}
|
||||
{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
|
||||
{%- set response_format = response_format if response_format is defined else None %}
|
||||
|
||||
{# Scan messages for VLM thinking toggles to override enable_thinking #}
|
||||
{%- set toggle = namespace(enable=enable_thinking) %}
|
||||
{%- for m in messages %}
|
||||
{%- if m['role'] == 'user' or m['role'] == 'system' -%}
|
||||
{%- if m['content'] is string -%}
|
||||
{%- set c = m['content'] %}
|
||||
{%- if '/think' in c.replace('</think>', '') -%}
|
||||
{%- set toggle.enable = true -%}
|
||||
{%- elif '/no_think' in c -%}
|
||||
{%- set toggle.enable = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
|
||||
{# Prepare message iteration similar to LM template #}
|
||||
{%- set ns = namespace(last_user_idx = -1) %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- for m in loop_messages %}
|
||||
{%- if m["role"] == "user" %}
|
||||
{%- set ns.last_user_idx = loop.index0 %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "" %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = [] %}
|
||||
{%- endif %}
|
||||
{# Recompute last_user_idx relative to loop_messages after handling system #}
|
||||
{%- set ns = namespace(last_user_idx = -1) %}
|
||||
{%- for m in loop_messages %}
|
||||
{%- if m["role"] == "user" %}
|
||||
{%- set ns.last_user_idx = loop.index0 %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{# System preamble with LM formatting, sanitize thinking toggles #}
|
||||
{%- if system_message is defined %}
|
||||
{%- set sys_content = system_message | string %}
|
||||
{%- set sys_content = sys_content.replace('</think>', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '</think>') %}
|
||||
{{- "<|im_start|>system\n" + sys_content }}
|
||||
{%- else %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- "<|im_start|>system\n" }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{%- if system_message is defined and system_message | length > 0 %}
|
||||
{{- "\n\n" }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou have access to the following functions:\n\n" }}
|
||||
{{- "<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
||||
{%- if tool.description is defined %}
|
||||
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
||||
{%- endif %}
|
||||
{{- '\n<parameters>' }}
|
||||
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- '\n<parameter>' }}
|
||||
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
||||
{%- if param_fields.type is defined %}
|
||||
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
||||
{%- endif %}
|
||||
{%- if param_fields.description is defined %}
|
||||
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
||||
{%- endif %}
|
||||
{%- if param_fields.enum is defined %}
|
||||
{{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
|
||||
{%- endif %}
|
||||
{%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
|
||||
{{- render_extra_keys(param_fields, handled_keys) }}
|
||||
{{- '\n</parameter>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% set handled_keys = ['type', 'properties', 'required'] %}
|
||||
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
||||
{%- if tool.parameters is defined and tool.parameters.required is defined %}
|
||||
{{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
|
||||
{%- endif %}
|
||||
{{- '\n</parameters>' }}
|
||||
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
||||
{{- render_extra_keys(tool, handled_keys) }}
|
||||
{{- '\n</function>' }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>" }}
|
||||
|
||||
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- if system_message is defined %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{# Iterate conversation #}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message.role == "assistant" %}
|
||||
{# Use LM assistant handling #}
|
||||
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
|
||||
{%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ (message.content | default('', true)) %}
|
||||
{%- else %}
|
||||
{%- set content = message.content | default('', true) %}
|
||||
{%- if content is string -%}
|
||||
{%- if '<think>' not in content and '</think>' not in content -%}
|
||||
{%- set content = "<think></think>" ~ content -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{%- set content = content -%}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
||||
{%- if content is string and content | trim | length > 0 %}
|
||||
{%- if include_content %}
|
||||
{{- (content | trim) ~ '\n' -}}
|
||||
{%- else %}
|
||||
{%- set c = (content | string) %}
|
||||
{%- if '</think>' in c %}
|
||||
{%- set c = c.split('</think>')[-1] %}
|
||||
{%- elif '<think>' in c %}
|
||||
{%- set c = c.split('<think>')[0] %}
|
||||
{%- endif %}
|
||||
{%- set c = "<think></think>" ~ c | trim %}
|
||||
{%- if c | length > 0 %}
|
||||
{{- c ~ '\n' -}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- "<think></think>" -}}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{%- for args_name, args_value in tool_call.arguments|items %}
|
||||
{{- '<parameter=' ~ args_name ~ '>\n' -}}
|
||||
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
||||
{{- args_value ~ '\n</parameter>\n' -}}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>\n</tool_call>\n' -}}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
||||
{{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- set c = (content | default('', true) | string) %}
|
||||
{%- if '<think>' in c and '</think>' in c %}
|
||||
{%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
|
||||
{%- endif %}
|
||||
{%- set c = c | trim %}
|
||||
{%- if c | length > 0 %}
|
||||
{{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>assistant\n<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- elif message.role == "user" or message.role == "system" %}
|
||||
{{- '<|im_start|>' + message.role + '\n' }}
|
||||
{%- set content = message.content | string %}
|
||||
{{- content }}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>user\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>\n' }}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{# Generation prompt using computed thinking toggle #}
|
||||
{%- if add_generation_prompt %}
|
||||
{%- if toggle.enable %}
|
||||
{{- '<|im_start|>assistant\n<think>\n' }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>assistant\n<think></think>' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
@@ -494,15 +494,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
ft := layer.GGML.KV().FileType()
|
||||
if quantType == "" && hasSourceFP8Tensors(layer.GGML.KV()) && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" && slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
|
||||
quantType = "Q8_0"
|
||||
}
|
||||
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
want, err := ggml.ParseFileType(quantType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ft := layer.GGML.KV().FileType()
|
||||
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
|
||||
return errors.New("quantization is only supported for F16 and F32 models")
|
||||
if !slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
|
||||
return errors.New("quantization is only supported for F16, BF16 and F32 models")
|
||||
} else if ft != want {
|
||||
layer, err = quantizeLayer(layer, quantType, fn)
|
||||
if err != nil {
|
||||
@@ -531,6 +534,12 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
}
|
||||
r.Parameters["stop"] = []string{"<turn|>"}
|
||||
}
|
||||
case "laguna":
|
||||
config.Renderer = cmp.Or(config.Renderer, "laguna")
|
||||
config.Parser = cmp.Or(config.Parser, "laguna")
|
||||
case "nemotron_h", "nemotron_h_moe", "nemotron_h_omni":
|
||||
config.Renderer = cmp.Or(config.Renderer, "nemotron-3-nano")
|
||||
config.Parser = cmp.Or(config.Parser, "nemotron-3-nano")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -606,6 +615,10 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSourceFP8Tensors(kv ggml.KV) bool {
|
||||
return kv.String("source_quantization") == "hf_fp8" && len(kv.Strings("source_fp8_tensors")) > 0
|
||||
}
|
||||
|
||||
func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
|
||||
ft := layer.GGML.KV().FileType()
|
||||
var doneBytes atomic.Uint64
|
||||
|
||||
90
server/laguna_quantization_test.go
Normal file
90
server/laguna_quantization_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
func TestLagunaGGUFQuantization(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
tensor string
|
||||
originalType fsggml.TensorType
|
||||
requestedType fsggml.TensorType
|
||||
fileType fsggml.FileType
|
||||
blockCount int
|
||||
wantType fsggml.TensorType
|
||||
wantQuantize bool
|
||||
}{
|
||||
{
|
||||
name: "non_routed_weights_preserved",
|
||||
tensor: "blk.1.attn_q.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ8_0,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeBF16,
|
||||
wantQuantize: false,
|
||||
},
|
||||
{
|
||||
name: "shared_expert_weights_preserved",
|
||||
tensor: "blk.1.ffn_gate_shexp.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeBF16,
|
||||
wantQuantize: false,
|
||||
},
|
||||
{
|
||||
name: "routed_gate_q8",
|
||||
tensor: "blk.1.ffn_gate_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ8_0,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ8_0,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_promoted",
|
||||
tensor: "blk.1.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ6_K,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_not_promoted_when_q8_requested",
|
||||
tensor: "blk.1.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ8_0,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_k_s_promoted",
|
||||
tensor: "blk.0.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_S,
|
||||
blockCount: 8,
|
||||
wantType: fsggml.TensorTypeQ5_K,
|
||||
wantQuantize: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotType, gotQuantize := lagunaGGUFQuantization(tt.tensor, tt.originalType, tt.requestedType, tt.fileType, tt.blockCount)
|
||||
if gotType != tt.wantType || gotQuantize != tt.wantQuantize {
|
||||
t.Fatalf("lagunaGGUFQuantization(%q) = (%s, %v), want (%s, %v)", tt.tensor, gotType, gotQuantize, tt.wantType, tt.wantQuantize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
@@ -51,11 +52,14 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
|
||||
type quantizeState struct {
|
||||
nAttnV int // Number of attn_*v* weight tensors
|
||||
nFfnDown int // Number of ffn_down tensors
|
||||
iAttnV int // Running counter of number of attn_v tensors that have been processed
|
||||
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
|
||||
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
|
||||
nAttnV int // Number of attn_*v* weight tensors
|
||||
nFfnDown int // Number of ffn_down tensors
|
||||
iAttnV int // Running counter of number of attn_v tensors that have been processed
|
||||
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
|
||||
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
|
||||
preserveSourceFP8ToQ8 bool
|
||||
preserveSourceQ4 bool
|
||||
sourceFP8Tensors map[string]struct{}
|
||||
}
|
||||
|
||||
func useMoreBits(iLayer, nLayers int) bool {
|
||||
@@ -108,6 +112,53 @@ func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func isLagunaGGUFRoutedExpertWeight(name string) bool {
|
||||
return strings.HasSuffix(name, ".weight") && (strings.Contains(name, "ffn_gate_exps") ||
|
||||
strings.Contains(name, "ffn_up_exps") ||
|
||||
strings.Contains(name, "ffn_down_exps"))
|
||||
}
|
||||
|
||||
func lagunaGGUFBlockIndex(name string) (int, bool) {
|
||||
if !strings.HasPrefix(name, "blk.") {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(strings.TrimPrefix(name, "blk."), ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return i, true
|
||||
}
|
||||
|
||||
func lagunaGGUFQuantization(name string, originalType, requestedType fsggml.TensorType, ftype fsggml.FileType, blockCount int) (fsggml.TensorType, bool) {
|
||||
if !isLagunaGGUFRoutedExpertWeight(name) {
|
||||
return originalType, false
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, ".ffn_down_exps.weight") {
|
||||
if i, ok := lagunaGGUFBlockIndex(name); ok && blockCount > 0 {
|
||||
switch ftype {
|
||||
case fsggml.FileTypeQ4_K_M:
|
||||
if requestedType != fsggml.TensorTypeQ8_0 && useMoreBits(i, blockCount) {
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
}
|
||||
case fsggml.FileTypeQ4_K_S:
|
||||
if requestedType != fsggml.TensorTypeQ8_0 && i < blockCount/8 {
|
||||
return fsggml.TensorTypeQ5_K, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return requestedType, true
|
||||
}
|
||||
|
||||
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
|
||||
// Ported from llama_tensor_get_type, removed unsupported quantization types
|
||||
nExperts := max(1, kv.Uint("expert_count", 0))
|
||||
@@ -120,10 +171,10 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
}
|
||||
} else if strings.Contains(name, "attn_v.weight") {
|
||||
if (ftype == fsggml.FileTypeQ4_K_M) &&
|
||||
if newType != fsggml.TensorTypeQ8_0 && (ftype == fsggml.FileTypeQ4_K_M) &&
|
||||
useMoreBits(qs.iAttnV, qs.nAttnV) {
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
} else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
|
||||
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
|
||||
@@ -158,31 +209,35 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// expert alphabetically, so dense increments the counter and expert uses counter-1.
|
||||
var iLayer int
|
||||
if strings.Contains(name, "_exps") {
|
||||
if kv.Architecture() == "laguna" {
|
||||
goto finalize
|
||||
}
|
||||
iLayer = max(0, qs.iFfnDown-1)
|
||||
} else {
|
||||
iLayer = qs.iFfnDown
|
||||
qs.iFfnDown++
|
||||
}
|
||||
n_layer := qs.nFfnDown
|
||||
if ftype == fsggml.FileTypeQ4_K_M {
|
||||
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
|
||||
if useMoreBits(iLayer, n_layer) {
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
}
|
||||
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
} else if strings.Contains(name, "attn_output.weight") {
|
||||
if nExperts == 8 {
|
||||
if newType != fsggml.TensorTypeQ8_0 && nExperts == 8 {
|
||||
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(name, "attn_qkv.weight") {
|
||||
if ftype == fsggml.FileTypeQ4_K_M {
|
||||
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
}
|
||||
|
||||
finalize:
|
||||
if newType.IsQuantized() {
|
||||
nx := shape[0]
|
||||
qk_k := newType.BlockSize()
|
||||
@@ -218,7 +273,12 @@ func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType,
|
||||
kv := maps.Clone(orig.KV())
|
||||
kv["general.file_type"] = newFileType
|
||||
// kv["general.quantization_version"] = ggml.QuantizationVersion()
|
||||
qs := &quantizeState{}
|
||||
qs := &quantizeState{
|
||||
sourceFP8Tensors: sourceFP8TensorSet(kv),
|
||||
}
|
||||
hasSourceFP8 := hasSourceFP8Tensors(kv)
|
||||
qs.preserveSourceFP8ToQ8 = hasSourceFP8 && newFileType == fsggml.FileTypeQ8_0
|
||||
qs.preserveSourceQ4 = hasSourceFP8 && slices.Contains([]fsggml.FileType{fsggml.FileTypeQ4_K_M, fsggml.FileTypeQ4_K_S}, newFileType)
|
||||
// Build up the quantize state so newType can adjust types
|
||||
layerCount := 0
|
||||
for k, l := range orig.Tensors().GroupLayers() {
|
||||
@@ -304,13 +364,34 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
if qs.preserveSourceFP8ToQ8 {
|
||||
if _, ok := qs.sourceFP8Tensors[name]; !ok {
|
||||
return newType
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3LinearAttnQuantType(name); ok {
|
||||
return qt
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Consider extracting architecture-specific GGUF quantization policy
|
||||
// from server so different quantization backends can share one source of
|
||||
// truth for model-family specializations.
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if qs.preserveSourceQ4 {
|
||||
if _, ok := qs.sourceFP8Tensors[name]; !ok {
|
||||
defaultType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
}
|
||||
if kv.Architecture() == "laguna" {
|
||||
var ok bool
|
||||
defaultType, ok = lagunaGGUFQuantization(name, newType, defaultType, ftype, int(kv.Uint("block_count", 0)))
|
||||
if !ok {
|
||||
return newType
|
||||
}
|
||||
}
|
||||
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
|
||||
if newType != defaultType {
|
||||
slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType)
|
||||
@@ -318,3 +399,16 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
}
|
||||
return newType
|
||||
}
|
||||
|
||||
func sourceFP8TensorSet(kv fsggml.KV) map[string]struct{} {
|
||||
names := kv.Strings("source_fp8_tensors")
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
out[name] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -308,6 +308,95 @@ func TestQuantizeModel(t *testing.T) {
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "source_fp8_q8_preserves_bf16_tensors",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"source_quantization": "hf_fp8",
|
||||
"source_fp8_tensors": []string{"blk.1.ffn_down_exps.weight"},
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
},
|
||||
newType: "Q8_0",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_q.weight": fsggml.TensorTypeBF16,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "source_fp8_q4_promotes_bf16_tensors_to_q8",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"source_quantization": "hf_fp8",
|
||||
"source_fp8_tensors": []string{
|
||||
"blk.1.ffn_gate_exps.weight",
|
||||
"blk.1.ffn_down_exps.weight",
|
||||
},
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.1.ffn_gate_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_v.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_down.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q_norm.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_gate_inp.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "output.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K_M",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.1.ffn_gate_exps.weight": fsggml.TensorTypeQ4_K,
|
||||
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ6_K,
|
||||
"blk.1.attn_q.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_v.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.ffn_down.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_q_norm.weight": fsggml.TensorTypeBF16,
|
||||
"blk.1.ffn_gate_inp.weight": fsggml.TensorTypeBF16,
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f32_short_data",
|
||||
kv: map[string]any{
|
||||
|
||||
@@ -618,8 +618,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
// only send messages with meaningful content (empty messages confuse clients)
|
||||
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
|
||||
// Emit chunks that carry logprobs even if the parser is still buffering
|
||||
// visible content, otherwise generate logprobs disappear for models with
|
||||
// builtin thinking/tool parsers.
|
||||
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 || len(res.Logprobs) > 0 {
|
||||
ch <- res
|
||||
}
|
||||
|
||||
|
||||
@@ -102,6 +102,36 @@ func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.Resp
|
||||
return w.ResponseRecorder
|
||||
}
|
||||
|
||||
func readCreatedModelConfig(t *testing.T, name string) model.ConfigV2 {
|
||||
t.Helper()
|
||||
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName(name))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for manifest")
|
||||
}
|
||||
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
cfgFile, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open config blob: %v", err)
|
||||
}
|
||||
defer cfgFile.Close()
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||
t.Fatalf("decode config: %v", err)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func checkFileExists(t *testing.T, p string, expect []string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -981,6 +1011,134 @@ func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLagunaDetectsRendererParser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "laguna",
|
||||
"general.parameter_count": uint64(33_400_000_000),
|
||||
}, nil)
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for manifest")
|
||||
}
|
||||
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
cfgFile, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open config blob: %v", err)
|
||||
}
|
||||
defer cfgFile.Close()
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||
t.Fatalf("decode config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Renderer != "laguna" {
|
||||
t.Fatalf("expected renderer %q, got %q", "laguna", cfg.Renderer)
|
||||
}
|
||||
if cfg.Parser != "laguna" {
|
||||
t.Fatalf("expected parser %q, got %q", "laguna", cfg.Parser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateNemotronHDefaultsRendererParser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, arch := range []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"} {
|
||||
t.Run(arch, func(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": arch,
|
||||
}, nil)
|
||||
|
||||
name := strings.ReplaceAll(arch, "_", "-")
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: name,
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
cfg := readCreatedModelConfig(t, name)
|
||||
if cfg.Renderer != "nemotron-3-nano" {
|
||||
t.Fatalf("expected renderer %q, got %q", "nemotron-3-nano", cfg.Renderer)
|
||||
}
|
||||
if cfg.Parser != "nemotron-3-nano" {
|
||||
t.Fatalf("expected parser %q, got %q", "nemotron-3-nano", cfg.Parser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateNemotronHDefaultsKeepExplicitRendererParser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, arch := range []string{"nemotron_h", "nemotron_h_moe", "nemotron_h_omni"} {
|
||||
t.Run(arch, func(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": arch,
|
||||
}, nil)
|
||||
|
||||
const (
|
||||
renderer = "custom-renderer"
|
||||
parser = "custom-parser"
|
||||
)
|
||||
|
||||
name := strings.ReplaceAll(arch, "_", "-") + "-custom"
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: name,
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Renderer: renderer,
|
||||
Parser: parser,
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
cfg := readCreatedModelConfig(t, name)
|
||||
if cfg.Renderer != renderer {
|
||||
t.Fatalf("expected renderer %q, got %q", renderer, cfg.Renderer)
|
||||
}
|
||||
if cfg.Parser != parser {
|
||||
t.Fatalf("expected parser %q, got %q", parser, cfg.Parser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectModelTypeFromFiles(t *testing.T) {
|
||||
t.Run("gguf file", func(t *testing.T) {
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
|
||||
@@ -1473,6 +1473,119 @@ func TestGenerateLogprobs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateLogprobsWithBuiltinParser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
responses := []llm.CompletionResponse{
|
||||
{
|
||||
Content: "h",
|
||||
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "h", Logprob: -0.1}}},
|
||||
},
|
||||
{
|
||||
Content: "i</think>Hello",
|
||||
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: "hi", Logprob: -0.2}}},
|
||||
},
|
||||
{
|
||||
Content: " world",
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
Logprobs: []llm.Logprob{{TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.3}}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
fn(resp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
req.successCh <- &runnerRef{llama: &mock}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
if w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-generate-logprob-parser",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Parser: "deepseek3",
|
||||
Template: `{{ .Prompt }}`,
|
||||
Stream: &stream,
|
||||
}); w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
noStream := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-generate-logprob-parser",
|
||||
Prompt: "Why is the sky blue?",
|
||||
Stream: &noStream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if got := len(resp.Logprobs); got != 3 {
|
||||
t.Fatalf("expected 3 logprob entries, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatLogprobs(t *testing.T) {
|
||||
t.Run("invalid top_logprobs negative", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.De
|
||||
|
||||
// Some architectures are not safe with num_parallel > 1.
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
modelparsers "github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -132,11 +133,11 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = inferSafetensorsCapabilities(opts.ModelDir)
|
||||
|
||||
// Set parser and renderer name based on architecture
|
||||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
@@ -183,7 +184,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func inferSafetensorsCapabilities(modelDir string) []string {
|
||||
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
|
||||
capabilities := []string{"completion"}
|
||||
|
||||
// Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures.
|
||||
@@ -195,7 +196,16 @@ func inferSafetensorsCapabilities(modelDir string) []string {
|
||||
capabilities = append(capabilities, "audio")
|
||||
}
|
||||
|
||||
if supportsThinking(modelDir) {
|
||||
var builtinParser modelparsers.Parser
|
||||
if parserName != "" {
|
||||
builtinParser = modelparsers.ParserForName(parserName)
|
||||
}
|
||||
|
||||
if builtinParser != nil && builtinParser.HasToolSupport() {
|
||||
capabilities = append(capabilities, "tools")
|
||||
}
|
||||
|
||||
if supportsThinking(modelDir) || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
|
||||
@@ -453,8 +463,8 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
||||
// This reads the config.json from the model directory and checks the architectures field.
|
||||
// supportsThinking checks if the model supports thinking mode based on known
|
||||
// architectures that do not expose a cleaner signal in their local metadata.
|
||||
func supportsThinking(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
@@ -554,6 +564,9 @@ func getParserName(modelDir string) string {
|
||||
// Check architectures for known parsers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "laguna") {
|
||||
return "laguna"
|
||||
}
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
@@ -571,6 +584,9 @@ func getParserName(modelDir string) string {
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "laguna") {
|
||||
return "laguna"
|
||||
}
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
@@ -608,6 +624,9 @@ func getRendererName(modelDir string) string {
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "laguna") {
|
||||
return "laguna"
|
||||
}
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
@@ -625,6 +644,9 @@ func getRendererName(modelDir string) string {
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "laguna") {
|
||||
return "laguna"
|
||||
}
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
|
||||
@@ -352,7 +352,7 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := inferSafetensorsCapabilities(dir); !slices.Equal(got, tt.want) {
|
||||
if got := inferSafetensorsCapabilities(dir, ""); !slices.Equal(got, tt.want) {
|
||||
t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
@@ -554,6 +554,11 @@ func TestSupportsThinking(t *testing.T) {
|
||||
configJSON: `{"model_type": "deepseek"}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "laguna architecture without template",
|
||||
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
configJSON: `{}`,
|
||||
@@ -584,6 +589,55 @@ func TestSupportsThinking_NoConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferSafetensorsCapabilitiesFromParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
parserName string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "laguna tools and thinking",
|
||||
parserName: "laguna",
|
||||
want: []string{"completion", "tools", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "functiongemma tools only",
|
||||
parserName: "functiongemma",
|
||||
want: []string{"completion", "tools"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := inferSafetensorsCapabilities(dir, tt.parserName); !slices.Equal(got, tt.want) {
|
||||
t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferSafetensorsCapabilitiesLaguna(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := inferSafetensorsCapabilities(dir, "laguna")
|
||||
for _, want := range []string{"completion", "tools", "thinking"} {
|
||||
if !slices.Contains(got, want) {
|
||||
t.Fatalf("capabilities %v missing %q", got, want)
|
||||
}
|
||||
}
|
||||
if slices.Contains(got, "vision") || slices.Contains(got, "audio") {
|
||||
t.Fatalf("unexpected non-text capability in %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParserName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -615,6 +669,11 @@ func TestGetParserName(t *testing.T) {
|
||||
configJSON: `{"model_type": "qwen3"}`,
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "laguna model",
|
||||
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
|
||||
want: "laguna",
|
||||
},
|
||||
{
|
||||
name: "no config",
|
||||
configJSON: `{}`,
|
||||
@@ -660,6 +719,11 @@ func TestGetRendererName(t *testing.T) {
|
||||
configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "laguna model",
|
||||
configJSON: `{"architectures": ["LagunaForCausalLM"], "model_type": "laguna"}`,
|
||||
want: "laguna",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -776,6 +776,11 @@ type tensorImportTransform interface {
|
||||
quantizationType(name string, shape []int32, quantize string) string
|
||||
}
|
||||
|
||||
type sourceFP8TensorImportTransform interface {
|
||||
sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string
|
||||
sourceFP8BF16Quantization(name string, shape []int32, requested string) string
|
||||
}
|
||||
|
||||
type noopImportTransform struct{}
|
||||
|
||||
func (noopImportTransform) skipTensor(string) bool { return false }
|
||||
@@ -804,6 +809,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
||||
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
"LagunaForCausalLM": newLagunaImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
@@ -842,6 +848,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to construct import transform for architecture %q: %w", sourceConfig.Architecture(), err)
|
||||
}
|
||||
sourceFP8Transform, _ := importTransform.(sourceFP8TensorImportTransform)
|
||||
|
||||
// Resolve the optional packed layer creator
|
||||
var packedCreator PackedTensorLayerCreator
|
||||
@@ -991,9 +998,17 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
// synthetic tests may not pass the generic import size filter.
|
||||
quantizeType = "mxfp8"
|
||||
}
|
||||
quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
|
||||
if sourceFP8Transform != nil {
|
||||
quantizeType = sourceFP8Transform.sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
|
||||
} else {
|
||||
quantizeType = sourceFP8TensorQuantization(outTD.Name, outTD.Shape, quantize, quantizeType)
|
||||
}
|
||||
case sourceQuantKind == sourceQuantizedKindSourceFP8:
|
||||
quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize)
|
||||
if sourceFP8Transform != nil {
|
||||
quantizeType = sourceFP8Transform.sourceFP8BF16Quantization(outTD.Name, outTD.Shape, quantize)
|
||||
} else {
|
||||
quantizeType = sourceFP8BF16PromotionQuantization(outTD.Name, outTD.Shape, quantize)
|
||||
}
|
||||
case effectiveQuantize != "":
|
||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
|
||||
}
|
||||
|
||||
59
x/create/laguna.go
Normal file
59
x/create/laguna.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
type lagunaImportTransform struct{}
|
||||
|
||||
func newLagunaImportTransform(string, sourceModelConfig) (tensorImportTransform, error) {
|
||||
return lagunaImportTransform{}, nil
|
||||
}
|
||||
|
||||
func (lagunaImportTransform) skipTensor(string) bool { return false }
|
||||
|
||||
func (lagunaImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
if td == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return []*safetensors.TensorData{td}, nil
|
||||
}
|
||||
|
||||
func (lagunaImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
||||
if !lagunaIsHFRoutedExpertWeight(name) {
|
||||
return ""
|
||||
}
|
||||
return GetTensorQuantization(name, shape, quantize)
|
||||
}
|
||||
|
||||
func (lagunaImportTransform) sourceFP8TensorQuantization(name string, shape []int32, requested string, fallback string) string {
|
||||
if !lagunaIsHFRoutedExpertWeight(name) {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch normalizeQuantType(requested) {
|
||||
case "nvfp4", "mxfp4":
|
||||
if lagunaKeepSourceFP8TensorAtMXFP8(name, shape) {
|
||||
return "mxfp8"
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (lagunaImportTransform) sourceFP8BF16Quantization(string, []int32, string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func lagunaKeepSourceFP8TensorAtMXFP8(name string, shape []int32) bool {
|
||||
if len(shape) != 2 || !isAligned(shape, "mxfp8") {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.Contains(name, "down_proj")
|
||||
}
|
||||
|
||||
func lagunaIsHFRoutedExpertWeight(name string) bool {
|
||||
return strings.HasSuffix(name, ".weight") && strings.Contains(name, ".mlp.experts.")
|
||||
}
|
||||
200
x/create/laguna_test.go
Normal file
200
x/create/laguna_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
st "github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
func TestCreateSafetensorsModel_LagunaHFFP8RespectsSourceTensorPrecision(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requested string
|
||||
wantFP8Gate string
|
||||
wantFP8Up string
|
||||
wantFP8Down string
|
||||
wantBF16QProj string
|
||||
}{
|
||||
{
|
||||
name: "default mxfp8 import keeps source bf16 tensors",
|
||||
requested: "",
|
||||
wantFP8Gate: "mxfp8",
|
||||
wantFP8Up: "mxfp8",
|
||||
wantFP8Down: "mxfp8",
|
||||
wantBF16QProj: "",
|
||||
},
|
||||
{
|
||||
name: "nvfp4 import keeps source bf16 tensors and preserves down_proj at mxfp8",
|
||||
requested: "nvfp4",
|
||||
wantFP8Gate: "nvfp4",
|
||||
wantFP8Up: "nvfp4",
|
||||
wantFP8Down: "mxfp8",
|
||||
wantBF16QProj: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configJSON := `{
|
||||
"model_type": "laguna",
|
||||
"architectures": ["LagunaForCausalLM"],
|
||||
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
})
|
||||
|
||||
quantizeByName := make(map[string]string)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quantizeByName[name] = quantize
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
if got := quantizeByName["model.layers.0.mlp.experts.0.gate_proj.weight"]; got != tt.wantFP8Gate {
|
||||
t.Fatalf("gate_proj quantization = %q, want %q", got, tt.wantFP8Gate)
|
||||
}
|
||||
if got := quantizeByName["model.layers.0.mlp.experts.0.up_proj.weight"]; got != tt.wantFP8Up {
|
||||
t.Fatalf("up_proj quantization = %q, want %q", got, tt.wantFP8Up)
|
||||
}
|
||||
if got := quantizeByName["model.layers.0.mlp.experts.0.down_proj.weight"]; got != tt.wantFP8Down {
|
||||
t.Fatalf("down_proj quantization = %q, want %q", got, tt.wantFP8Down)
|
||||
}
|
||||
for _, name := range []string{
|
||||
"model.layers.0.self_attn.q_proj.weight",
|
||||
"model.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
"model.layers.0.mlp.gate.weight",
|
||||
} {
|
||||
if got := quantizeByName[name]; got != tt.wantBF16QProj {
|
||||
t.Fatalf("%s quantization = %q, want %q", name, got, tt.wantBF16QProj)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_LagunaBF16QuantizesOnlyRoutedExperts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requested string
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "int8 quantizes only routed experts",
|
||||
requested: "int8",
|
||||
want: map[string]string{
|
||||
"model.layers.0.mlp.experts.0.gate_proj.weight": "int8",
|
||||
"model.layers.0.mlp.experts.0.up_proj.weight": "int8",
|
||||
"model.layers.0.mlp.experts.0.down_proj.weight": "int8",
|
||||
"model.layers.0.mlp.shared_experts.gate_proj.weight": "",
|
||||
"model.layers.0.mlp.shared_experts.down_proj.weight": "",
|
||||
"model.layers.0.self_attn.q_proj.weight": "",
|
||||
"model.layers.0.mlp.down_proj.weight": "",
|
||||
"model.embed_tokens.weight": "",
|
||||
"lm_head.weight": "",
|
||||
"model.layers.0.mlp.gate.weight": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int4 keeps routed down_proj at int8 and leaves others bf16",
|
||||
requested: "int4",
|
||||
want: map[string]string{
|
||||
"model.layers.0.mlp.experts.0.gate_proj.weight": "int4",
|
||||
"model.layers.0.mlp.experts.0.up_proj.weight": "int4",
|
||||
"model.layers.0.mlp.experts.0.down_proj.weight": "int8",
|
||||
"model.layers.0.mlp.shared_experts.gate_proj.weight": "",
|
||||
"model.layers.0.mlp.shared_experts.down_proj.weight": "",
|
||||
"model.layers.0.self_attn.q_proj.weight": "",
|
||||
"model.layers.0.mlp.down_proj.weight": "",
|
||||
"model.embed_tokens.weight": "",
|
||||
"lm_head.weight": "",
|
||||
"model.layers.0.mlp.gate.weight": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configJSON := `{
|
||||
"model_type": "laguna",
|
||||
"architectures": ["LagunaForCausalLM"]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.up_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.experts.0.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.gate_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.shared_experts.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.down_proj.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.mlp.gate.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
})
|
||||
|
||||
quantizeByName := make(map[string]string)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quantizeByName[name] = quantize
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, tt.requested, createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
for name, want := range tt.want {
|
||||
if got := quantizeByName[name]; got != want {
|
||||
t.Fatalf("%s quantization = %q, want %q", name, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/laguna"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
|
||||
@@ -34,6 +34,19 @@ var SiLU = Compile1(
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SoftplusF32 returns softplus(x) computed in float32 precision and cast back
|
||||
// to x's original dtype, as a fused kernel. Matches the laguna attention
|
||||
// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype).
|
||||
var SoftplusF32 = Compile1(
|
||||
"SoftplusF32",
|
||||
func(x *Array) *Array {
|
||||
dt := x.DType()
|
||||
zero := FromValue[float32](0)
|
||||
return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||
var SwiGLU = Compile2(
|
||||
"SwiGLU",
|
||||
|
||||
1216
x/models/laguna/laguna.go
Normal file
1216
x/models/laguna/laguna.go
Normal file
File diff suppressed because it is too large
Load Diff
509
x/models/laguna/laguna_test.go
Normal file
509
x/models/laguna/laguna_test.go
Normal file
@@ -0,0 +1,509 @@
|
||||
package laguna
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func TestParseConfigLagunaXS(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg, err := parseConfig([]byte(`{
|
||||
"model_type": "laguna",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 8192,
|
||||
"moe_intermediate_size": 512,
|
||||
"shared_expert_intermediate_size": 512,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 48,
|
||||
"num_attention_heads_per_layer": [48, 64, 64, 64],
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"vocab_size": 100352,
|
||||
"max_position_embeddings": 131072,
|
||||
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
|
||||
"sliding_window": 512,
|
||||
"mlp_only_layers": [0],
|
||||
"decoder_sparse_step": 1,
|
||||
"num_experts": 256,
|
||||
"num_experts_per_tok": 8,
|
||||
"norm_topk_prob": true,
|
||||
"moe_routed_scaling_factor": 2.5,
|
||||
"gating": "per-head",
|
||||
"rms_norm_eps": 1e-6,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"rope_type": "yarn",
|
||||
"factor": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"beta_fast": 64,
|
||||
"beta_slow": 1,
|
||||
"attention_factor": 1
|
||||
},
|
||||
"swa_rope_parameters": {
|
||||
"partial_rotary_factor": 1.0,
|
||||
"rope_theta": 10000,
|
||||
"rope_type": "linear"
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if cfg.FullRopeDim != 64 {
|
||||
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
|
||||
}
|
||||
if cfg.FullRopeBase != 500000 {
|
||||
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
|
||||
}
|
||||
if cfg.FullRopeScale != 1 {
|
||||
t.Fatalf("FullRopeScale = %v, want explicit YaRN attention_factor", cfg.FullRopeScale)
|
||||
}
|
||||
if cfg.FullRopeFreqs == nil {
|
||||
t.Fatal("FullRopeFreqs should be precomputed for YaRN")
|
||||
}
|
||||
if cfg.SlidingRopeDim != 128 {
|
||||
t.Fatalf("SlidingRopeDim = %d, want 128", cfg.SlidingRopeDim)
|
||||
}
|
||||
if cfg.SlidingRopeBase != 10000 {
|
||||
t.Fatalf("SlidingRopeBase = %v, want 10000", cfg.SlidingRopeBase)
|
||||
}
|
||||
if !layerIsSliding(&cfg, 1) {
|
||||
t.Fatal("layer 1 should use sliding attention")
|
||||
}
|
||||
if layerUsesMoE(&cfg, 0) {
|
||||
t.Fatal("layer 0 should be dense due to mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(&cfg, 1) {
|
||||
t.Fatal("layer 1 should use MoE")
|
||||
}
|
||||
if got := numHeadsForLayer(&cfg, 1); got != 64 {
|
||||
t.Fatalf("numHeadsForLayer(1) = %d, want 64", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigLagunaFP8RopeScaling(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg, err := parseConfig([]byte(`{
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 8192,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 48,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"vocab_size": 100352,
|
||||
"max_position_embeddings": 131072,
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"rope_scaling": {
|
||||
"rope_type": "yarn",
|
||||
"factor": 32
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.FullRopeBase != 500000 {
|
||||
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
|
||||
}
|
||||
if cfg.FullRopeDim != 64 {
|
||||
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigLagunaGASchema(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg, err := parseConfig([]byte(`{
|
||||
"model_type": "laguna",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 8192,
|
||||
"moe_intermediate_size": 512,
|
||||
"shared_expert_intermediate_size": 512,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 48,
|
||||
"num_attention_heads_per_layer": [48, 64, 64, 64],
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"vocab_size": 100352,
|
||||
"max_position_embeddings": 131072,
|
||||
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
|
||||
"sliding_window": 512,
|
||||
"mlp_layer_types": ["dense", "sparse", "sparse", "sparse"],
|
||||
"num_experts": 256,
|
||||
"num_experts_per_tok": 8,
|
||||
"moe_routed_scaling_factor": 2.5,
|
||||
"gating": true,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"rope_theta": 500000,
|
||||
"rope_type": "yarn",
|
||||
"factor": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"beta_fast": 64,
|
||||
"beta_slow": 1,
|
||||
"attention_factor": 1,
|
||||
"partial_rotary_factor": 0.5
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000,
|
||||
"rope_type": "default",
|
||||
"partial_rotary_factor": 1.0
|
||||
}
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if cfg.Gating != "per-head" {
|
||||
t.Fatalf("Gating = %q, want per-head", cfg.Gating)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatal("NormTopKProb should default true")
|
||||
}
|
||||
if cfg.FullRopeBase != 500000 {
|
||||
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
|
||||
}
|
||||
if cfg.SlidingRopeBase != 10000 {
|
||||
t.Fatalf("SlidingRopeBase = %v, want 10000", cfg.SlidingRopeBase)
|
||||
}
|
||||
if cfg.FullRopeDim != 64 {
|
||||
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
|
||||
}
|
||||
if cfg.SlidingRopeDim != 128 {
|
||||
t.Fatalf("SlidingRopeDim = %d, want 128", cfg.SlidingRopeDim)
|
||||
}
|
||||
if layerUsesMoE(&cfg, 0) {
|
||||
t.Fatal("layer 0 should be dense due to mlp_layer_types")
|
||||
}
|
||||
if !layerUsesMoE(&cfg, 1) {
|
||||
t.Fatal("layer 1 should use MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTinyLagunaLoadAndForward(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg, err := parseConfig([]byte(`{
|
||||
"model_type": "laguna",
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 12,
|
||||
"moe_intermediate_size": 4,
|
||||
"shared_expert_intermediate_size": 4,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_attention_heads_per_layer": [2, 2],
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 4,
|
||||
"vocab_size": 16,
|
||||
"max_position_embeddings": 64,
|
||||
"layer_types": ["full_attention", "sliding_attention"],
|
||||
"sliding_window": 2,
|
||||
"mlp_only_layers": [0],
|
||||
"decoder_sparse_step": 1,
|
||||
"num_experts": 2,
|
||||
"num_experts_per_tok": 1,
|
||||
"norm_topk_prob": false,
|
||||
"moe_routed_scaling_factor": 2.5,
|
||||
"gating": "per-head",
|
||||
"rms_norm_eps": 1e-5,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 10000,
|
||||
"rope_type": "yarn",
|
||||
"factor": 2,
|
||||
"original_max_position_embeddings": 16,
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1
|
||||
},
|
||||
"swa_rope_parameters": {
|
||||
"partial_rotary_factor": 1.0,
|
||||
"rope_theta": 10000,
|
||||
"rope_type": "linear"
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Config: &cfg,
|
||||
Layers: []*Layer{
|
||||
{LayerIdx: 0, IsSliding: false},
|
||||
{LayerIdx: 1, IsSliding: true},
|
||||
},
|
||||
}
|
||||
tensors := tinyLagunaTensors()
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
t.Fatalf("LoadWeights failed: %v", err)
|
||||
}
|
||||
|
||||
tokens := mlx.FromValues([]int32{1, 2, 3}, 1, 3)
|
||||
caches := m.NewCaches()
|
||||
defer func() {
|
||||
for _, c := range caches {
|
||||
if c != nil {
|
||||
c.Free()
|
||||
}
|
||||
}
|
||||
}()
|
||||
hidden := m.Forward(&batch.Batch{
|
||||
InputIDs: tokens,
|
||||
SeqOffsets: []int32{0},
|
||||
SeqQueryLens: []int32{int32(tokens.Dim(1))},
|
||||
}, caches)
|
||||
mlx.Eval(hidden)
|
||||
if got := hidden.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 8 {
|
||||
t.Fatalf("hidden shape = %v, want [1 3 8]", got)
|
||||
}
|
||||
|
||||
logits := m.Unembed(hidden)
|
||||
mlx.Eval(logits)
|
||||
if got := logits.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 16 {
|
||||
t.Fatalf("logits shape = %v, want [1 3 16]", got)
|
||||
}
|
||||
for i, v := range logits.Floats() {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Fatalf("logits[%d] is not finite: %v", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTinyLagunaLoadWeightsFusesDenseGateUp(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg, err := parseConfig([]byte(`{
|
||||
"model_type": "laguna",
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 12,
|
||||
"moe_intermediate_size": 4,
|
||||
"shared_expert_intermediate_size": 4,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_attention_heads_per_layer": [2, 2],
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 4,
|
||||
"vocab_size": 16,
|
||||
"max_position_embeddings": 64,
|
||||
"layer_types": ["full_attention", "sliding_attention"],
|
||||
"sliding_window": 2,
|
||||
"mlp_only_layers": [0],
|
||||
"decoder_sparse_step": 1,
|
||||
"num_experts": 2,
|
||||
"num_experts_per_tok": 1,
|
||||
"norm_topk_prob": false,
|
||||
"moe_routed_scaling_factor": 2.5,
|
||||
"gating": "per-head",
|
||||
"rms_norm_eps": 1e-5
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Config: &cfg,
|
||||
Layers: []*Layer{
|
||||
{LayerIdx: 0, IsSliding: false},
|
||||
{LayerIdx: 1, IsSliding: true},
|
||||
},
|
||||
}
|
||||
if err := m.LoadWeights(tinyLagunaTensors()); err != nil {
|
||||
t.Fatalf("LoadWeights failed: %v", err)
|
||||
}
|
||||
|
||||
moe, ok := m.Layers[1].MLP.(*SparseMoE)
|
||||
if !ok {
|
||||
t.Fatalf("layer 1 MLP type = %T, want *SparseMoE", m.Layers[1].MLP)
|
||||
}
|
||||
if !moe.SwitchMLP.UseFusedGateUp {
|
||||
t.Fatal("expected dense SwitchMLP to fuse gate/up expert weights")
|
||||
}
|
||||
if moe.SwitchMLP.GateUpWeight == nil {
|
||||
t.Fatal("expected fused GateUpWeight to be populated")
|
||||
}
|
||||
if got, want := moe.SwitchMLP.GateUpWeight.Dims(), []int{2, 8, 8}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] || got[2] != want[2] {
|
||||
t.Fatalf("GateUpWeight dims = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSparseMoERouteBiasAffectsSelectionNotRoutingWeights(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg := &Config{
|
||||
HiddenSize: 1,
|
||||
NumExperts: 2,
|
||||
NumExpertsPerTok: 1,
|
||||
NormTopKProb: false,
|
||||
}
|
||||
|
||||
moe := &SparseMoE{
|
||||
Gate: nn.NewLinear(mlx.FromValues([]float32{-4, -3}, 2, 1).AsType(mlx.DTypeBFloat16), nil),
|
||||
EScoreCorrectionBias: mlx.FromValues([]float32{0.5, 0}, 2),
|
||||
}
|
||||
|
||||
xFlat := mlx.FromValues([]float32{1}, 1, int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)
|
||||
scores, inds := moe.route(xFlat, cfg)
|
||||
scores = scores.AsType(mlx.DTypeFloat32)
|
||||
inds = inds.AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
gates := moe.Gate.Forward(xFlat).AsType(mlx.DTypeFloat32)
|
||||
probs := mlx.Sigmoid(gates)
|
||||
mlx.Eval(probs)
|
||||
probVals := probs.Floats()
|
||||
if probVals[0] >= probVals[1] {
|
||||
t.Fatalf("expected unbiased sigmoid scores to prefer expert 1, got %v", probVals)
|
||||
}
|
||||
if probVals[0]+0.5 <= probVals[1] {
|
||||
t.Fatalf("expected bias to flip selection to expert 0, got probs=%v", probVals)
|
||||
}
|
||||
if got := inds.Ints(); len(got) != 1 || got[0] != 0 {
|
||||
t.Fatalf("selected experts = %v, want [0]", got)
|
||||
}
|
||||
if got := scores.Floats(); len(got) != 1 || math.Abs(float64(got[0]-probVals[0])) > 1e-6 {
|
||||
t.Fatalf("routing weights = %v, want [%v] using unbiased sigmoid scores", got, probVals[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchMLPFusedGateUpMatchesSeparate(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
cfg := &Config{HiddenSize: 4, NumExpertsPerTok: 2}
|
||||
B, L := int32(2), int32(3)
|
||||
xVals := make([]float32, int(B*L*cfg.HiddenSize))
|
||||
for i := range xVals {
|
||||
xVals[i] = float32((i%17)-8) * 0.01
|
||||
}
|
||||
x := mlx.FromValues(xVals, int(B), int(L), int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
indicesVals := make([]int32, B*L*cfg.NumExpertsPerTok)
|
||||
for i := 0; i < len(indicesVals); i += int(cfg.NumExpertsPerTok) {
|
||||
indicesVals[i] = int32((i / int(cfg.NumExpertsPerTok)) % 2)
|
||||
indicesVals[i+1] = int32(((i / int(cfg.NumExpertsPerTok)) + 1) % 2)
|
||||
}
|
||||
indices := mlx.FromValues(indicesVals, int(B*L), int(cfg.NumExpertsPerTok))
|
||||
|
||||
separate := &SwitchMLP{
|
||||
GateWeight: makePatternExpertWeight(2, 4, 3, 0.011),
|
||||
UpWeight: makePatternExpertWeight(2, 4, 3, 0.017),
|
||||
DownWeight: makePatternExpertWeight(2, 3, 4, 0.013),
|
||||
}
|
||||
fused := &SwitchMLP{
|
||||
GateUpWeight: fuseExpertStacks(separate.GateWeight, separate.UpWeight, 2),
|
||||
DownWeight: separate.DownWeight,
|
||||
UseFusedGateUp: true,
|
||||
}
|
||||
|
||||
gotSeparate := separate.Forward(x, indices, cfg)
|
||||
gotFused := fused.Forward(x, indices, cfg)
|
||||
mlx.Eval(gotSeparate, gotFused)
|
||||
|
||||
gotFusedF32 := gotFused.AsType(mlx.DTypeFloat32)
|
||||
gotSeparateF32 := gotSeparate.AsType(mlx.DTypeFloat32)
|
||||
mlx.Eval(gotFusedF32, gotSeparateF32)
|
||||
assertFloatSlicesClose(t, gotFusedF32.Floats(), gotSeparateF32.Floats(), 1e-5)
|
||||
}
|
||||
|
||||
func TestCombinedTensorGlobalScaleIgnoresInputGlobalScale(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
tensors := map[string]*mlx.Array{
|
||||
"proj.weight.global_scale": mlx.FromValues([]float32{0.25}, 1),
|
||||
"proj.weight.input_global_scale": mlx.FromValues([]float32{8}, 1),
|
||||
}
|
||||
|
||||
got, _ := combinedTensorGlobalScale(tensors, "proj.weight")
|
||||
if got == nil {
|
||||
t.Fatal("combinedTensorGlobalScale returned nil")
|
||||
}
|
||||
mlx.Eval(got)
|
||||
vals := got.Floats()
|
||||
if len(vals) != 1 || vals[0] != 0.25 {
|
||||
t.Fatalf("combinedTensorGlobalScale = %v, want [0.25]", vals)
|
||||
}
|
||||
}
|
||||
|
||||
func tinyLagunaTensors() map[string]*mlx.Array {
|
||||
tensors := map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": weights(16, 8),
|
||||
"model.norm.weight": ones(8),
|
||||
"lm_head.weight": weights(16, 8),
|
||||
}
|
||||
for layer := range 2 {
|
||||
prefix := "model.layers." + string(rune('0'+layer))
|
||||
tensors[prefix+".input_layernorm.weight"] = ones(8)
|
||||
tensors[prefix+".post_attention_layernorm.weight"] = ones(8)
|
||||
tensors[prefix+".self_attn.q_proj.weight"] = weights(8, 8)
|
||||
tensors[prefix+".self_attn.k_proj.weight"] = weights(4, 8)
|
||||
tensors[prefix+".self_attn.v_proj.weight"] = weights(4, 8)
|
||||
tensors[prefix+".self_attn.o_proj.weight"] = weights(8, 8)
|
||||
tensors[prefix+".self_attn.g_proj.weight"] = weights(2, 8)
|
||||
tensors[prefix+".self_attn.q_norm.weight"] = ones(4)
|
||||
tensors[prefix+".self_attn.k_norm.weight"] = ones(4)
|
||||
}
|
||||
|
||||
tensors["model.layers.0.mlp.gate_proj.weight"] = weights(12, 8)
|
||||
tensors["model.layers.0.mlp.up_proj.weight"] = weights(12, 8)
|
||||
tensors["model.layers.0.mlp.down_proj.weight"] = weights(8, 12)
|
||||
|
||||
tensors["model.layers.1.mlp.gate.weight"] = weights(2, 8)
|
||||
tensors["model.layers.1.mlp.experts.e_score_correction_bias"] = mlx.FromValues([]float32{0.1, -0.1}, 2)
|
||||
for expert := range 2 {
|
||||
prefix := "model.layers.1.mlp.experts." + string(rune('0'+expert))
|
||||
tensors[prefix+".gate_proj.weight"] = weights(4, 8)
|
||||
tensors[prefix+".up_proj.weight"] = weights(4, 8)
|
||||
tensors[prefix+".down_proj.weight"] = weights(8, 4)
|
||||
}
|
||||
tensors["model.layers.1.mlp.shared_expert.gate_proj.weight"] = weights(4, 8)
|
||||
tensors["model.layers.1.mlp.shared_expert.up_proj.weight"] = weights(4, 8)
|
||||
tensors["model.layers.1.mlp.shared_expert.down_proj.weight"] = weights(8, 4)
|
||||
return tensors
|
||||
}
|
||||
|
||||
func makeExpertWeight(vals []float32, dims ...int) *mlx.Array {
|
||||
return mlx.FromValues(vals, dims...).AsType(mlx.DTypeBFloat16)
|
||||
}
|
||||
|
||||
func makePatternExpertWeight(numExperts, rows, cols int, scale float32) *mlx.Array {
|
||||
vals := make([]float32, numExperts*rows*cols)
|
||||
for i := range vals {
|
||||
vals[i] = float32((i%23)-11) * scale
|
||||
}
|
||||
return makeExpertWeight(vals, numExperts, rows, cols)
|
||||
}
|
||||
|
||||
func assertFloatSlicesClose(t *testing.T, got, want []float32, tol float64) {
|
||||
t.Helper()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
|
||||
}
|
||||
for i := range got {
|
||||
if math.Abs(float64(got[i]-want[i])) > tol {
|
||||
t.Fatalf("value[%d] = %v, want %v (tol=%g)", i, got[i], want[i], tol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func weights(rows, cols int) *mlx.Array {
|
||||
vals := make([]float32, rows*cols)
|
||||
for i := range vals {
|
||||
vals[i] = float32((i%7)-3) * 0.01
|
||||
}
|
||||
return mlx.FromValues(vals, rows, cols)
|
||||
}
|
||||
|
||||
func ones(n int) *mlx.Array {
|
||||
vals := make([]float32, n)
|
||||
for i := range vals {
|
||||
vals[i] = 1
|
||||
}
|
||||
return mlx.FromValues(vals, n)
|
||||
}
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -447,7 +447,9 @@ func extractPretokenizer(data json.RawMessage) string {
|
||||
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
|
||||
for _, pt := range seq.Pretokenizers {
|
||||
if pt.Type == "Split" && pt.Pattern.Regex != "" {
|
||||
return pt.Pattern.Regex
|
||||
if _, err := regexp.Compile(rewritePatternForRE2(pt.Pattern.Regex)); err == nil {
|
||||
return pt.Pattern.Regex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,3 +22,31 @@ func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPretokenizerSkipsUnsupportedSequenceSplit(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"type": "Sequence",
|
||||
"pretokenizers": [
|
||||
{
|
||||
"type": "Split",
|
||||
"pattern": {
|
||||
"Regex": "(?:\\r?\\n)+(?!\\r?\\n)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Split",
|
||||
"pattern": {
|
||||
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
pattern := extractPretokenizer(data)
|
||||
if pattern == "" {
|
||||
t.Fatal("expected supported Split pretokenizer")
|
||||
}
|
||||
if strings.Contains(pattern, `(?!\r?\n)`) {
|
||||
t.Fatalf("selected unsupported newline splitter: %q", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user