Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Yang
9ef2106b47 cmd: create blob in parallel with checksum
a simple optimisation where once a blob has been checksumed, immediately
upload it; don't wait for all files to be checksumed before starting
upload.
2026-01-20 12:09:02 -08:00
42 changed files with 281 additions and 6662 deletions

View File

@@ -749,7 +749,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`

View File

@@ -43,7 +43,6 @@ import (
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
@@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
spinner.Stop()
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
@@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
g.SetLimit(runtime.GOMAXPROCS(0))
for blob, err := range createBlobs(req.Files, req.Adapters) {
if err != nil {
return err
}
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
_, err := createBlob(cmd, client, blob.Abs, blob.Digest, p)
return err
})
}
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
})
if _, ok := req.Files[blob.Rel]; ok {
req.Files[blob.Rel] = blob.Digest
} else if _, ok := req.Adapters[blob.Rel]; ok {
req.Adapters[blob.Rel] = blob.Digest
}
}
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
spinner.Stop()
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
@@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return "", err
}
bin, err := os.Open(realPath)
if err != nil {
return "", err
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
var pw progressWriter
status := fmt.Sprintf("copying file %s 0%%", digest)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer spinner.Stop()
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
return
}
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}
type progressWriter struct {
n atomic.Int64
}
@@ -899,11 +836,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, arg := range args {
// Unload the model if it's running before deletion
if err := loadOrUnloadModel(cmd, &runOptions{
Model: arg,
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
}
}

103
cmd/create.go Normal file
View File

@@ -0,0 +1,103 @@
package cmd
import (
"crypto/sha256"
"fmt"
"io"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
type blob struct {
Rel, Abs, Digest string
}
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return "", err
}
bin, err := os.Open(realPath)
if err != nil {
return "", err
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
var pw progressWriter
status := fmt.Sprintf("copying file %s 0%%", digest)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer spinner.Stop()
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
return
}
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}
func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] {
return func(yield func(blob, error) bool) {
for _, mapping := range mappings {
for rel, abs := range mapping {
if abs, ok := strings.CutPrefix(abs, "abs:"); ok {
f, err := os.Open(abs)
if err != nil {
yield(blob{}, err)
return
}
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
yield(blob{}, err)
return
}
if err := f.Close(); err != nil {
yield(blob{}, err)
return
}
if !yield(blob{
Rel: rel,
Abs: abs,
Digest: fmt.Sprintf("sha256:%x", h.Sum(nil)),
}, nil) {
return
}
}
}
}
}
}

View File

@@ -311,10 +311,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,150 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type glm4MoeLiteModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "glm4moelite"
kv["general.type"] = "model"
kv["glm4moelite.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["glm4moelite.attention.head_count"] = numHeads
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
kv["glm4moelite.attention.value_length"] = p.VHeadDim
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
kv["glm4moelite.embedding_length"] = p.HiddenSize
kv["glm4moelite.expert_count"] = p.ExpertCount
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
kv["glm4moelite.expert_gating_func"] = uint32(2)
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["tokenizer.ggml.pre"] = "glm4"
return kv
}
func (p *glm4MoeLiteModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -1,100 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights: [D, 1, K] -> [D, K]
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
if len(shape) == 3 && shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}

View File

@@ -40,7 +40,6 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -269,8 +269,6 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}
@@ -858,9 +856,7 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"glm4moelite",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",

View File

@@ -1,148 +0,0 @@
//go:build integration
package integration
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
var imageBase64 string
err := client.Generate(ctx, &api.GenerateRequest{
Model: model,
Prompt: prompt,
}, func(resp api.GenerateResponse) error {
if resp.Image != "" {
imageBase64 = resp.Image
}
return nil
})
if err != nil {
return "", fmt.Errorf("failed to generate image: %w", err)
}
if imageBase64 == "" {
return "", fmt.Errorf("no image data in response")
}
return imageBase64, nil
}

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -144,7 +143,6 @@ var (
"granite3.3",
"hermes3",
"internlm2",
"lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -265,7 +263,6 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",

View File

@@ -162,7 +162,6 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1641,13 +1641,6 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -1,304 +0,0 @@
package glm4moelite
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads int
eps,
ropeBase float32
kqScale float64
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
compressedKV.Stride(1), 1,
compressedKV.Stride(1), compressedKV.Dim(1),
)
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
if moe.ExpProbsBias != nil {
scores = scores.Add(ctx, moe.ExpProbsBias)
}
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "glm4":
pre = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
pre...,
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glm4moelite", New)
}

View File

@@ -1,410 +0,0 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
// without modifying permanent state (slotForSeq, refCount)
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int // track newly allocated slots that need zeroing
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero conv state for newly allocated slots to clear stale data from previous sequences
if len(newSlots) > 0 {
c.zeroConvSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 || len(c.convStates) == 0 {
return
}
// Use input context for creating tensors
inputCtx := ctx.Input()
// Create slot indices tensor
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Create zero tensor for the slots (SetRows requires F32 source)
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
// Zero each layer's conv state for these slots
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -1,444 +0,0 @@
package lfm2
import (
"testing"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
return &HybridCache{
hiddenSize: 256,
dConv: 3,
maxSequences: maxSequences,
refCount: make([]int, maxSequences),
freeSlots: initFreeSlots(maxSequences),
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func initFreeSlots(n int) []int {
slots := make([]int, 0, n)
for i := n - 1; i >= 0; i-- {
slots = append(slots, i)
}
return slots
}
func TestHybridCache_SlotAllocation(t *testing.T) {
cache := createSlotOnlyCache(4)
// Verify initial state
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
}
// Allocate all slots
for range 4 {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.refCount[slot] = 1
}
// Should be full now
if len(cache.freeSlots) != 0 {
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
}
// Trying to allocate another should fail
_, err := cache.allocSlot()
if err != kvcache.ErrKvCacheFull {
t.Errorf("expected ErrKvCacheFull, got %v", err)
}
}
func TestHybridCache_SlotReuse(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate a slot
slot1, _ := cache.allocSlot()
cache.refCount[slot1] = 1
// Free it
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
// Allocate again - should get the same slot back (LIFO)
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
}
}
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Simulate sharing slot with seq 2 (copy-on-write style)
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Should share the same slot
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Ref count should be 2
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share with seq 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Unshare seq 2
cache.refCount[slot1]--
delete(cache.slotForSeq, 2)
// Ref count should be back to 1
if cache.refCount[slot1] != 1 {
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
}
// Seq 2 should no longer have a slot
if _, ok := cache.slotForSeq[2]; ok {
t.Error("seq 2 should not have a slot after unshare")
}
}
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
cache := createSlotOnlyCache(4)
initialFreeSlots := len(cache.freeSlots)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot when refCount drops to 0
cache.refCount[slot1]--
if cache.refCount[slot1] <= 0 {
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
}
delete(cache.slotForSeq, 1)
// Slot should be freed
if len(cache.freeSlots) != initialFreeSlots {
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
}
// Ref count should be 0
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotOverwrite(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slots for seq 1 and seq 2
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
slot2, _ := cache.allocSlot()
cache.slotForSeq[2] = slot2
cache.refCount[slot2] = 1
initialFreeSlots := len(cache.freeSlots)
// Simulate overwriting seq 2's slot with slot1 (sharing)
// First free the old slot
cache.refCount[slot2]--
if cache.refCount[slot2] <= 0 {
cache.refCount[slot2] = 0
cache.freeSlot(slot2)
}
// Then share slot1
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Seq 2 should now share slot1
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Old slot2 should be freed
if len(cache.freeSlots) != initialFreeSlots+1 {
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
}
}
func TestHybridCache_BoundsChecking(t *testing.T) {
cache := createSlotOnlyCache(4)
// Test freeing invalid slot (should not panic)
cache.freeSlot(-1)
cache.freeSlot(100) // out of bounds
// freeSlot does bounds checking, so invalid slots should be ignored
if len(cache.freeSlots) != 4 {
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
}
}
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
cache := createSlotOnlyCache(8)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Fork to seq 2, 3, 4 (all share slot1)
for _, seq := range []int{2, 3, 4} {
cache.slotForSeq[seq] = slot1
cache.refCount[slot1]++
}
// Ref count should be 4
if cache.refCount[slot1] != 4 {
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
}
// Remove seq 2, 3
for _, seq := range []int{2, 3} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
// Slot should still be allocated (not in free list)
found := false
for _, s := range cache.freeSlots {
if s == slot1 {
found = true
break
}
}
if found {
t.Error("slot1 should not be in free list yet")
}
// Remove remaining sequences
for _, seq := range []int{1, 4} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_ChainedSharing(t *testing.T) {
cache := createSlotOnlyCache(8)
// Create seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share 1 -> 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Share 2 -> 3 (should still share slot1)
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
cache.refCount[slot1]++
// All should share slot1
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
t.Error("all sequences should share slot1")
}
if cache.refCount[slot1] != 3 {
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_CacheParameters(t *testing.T) {
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
if cache.hiddenSize != 512 {
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
}
if cache.dConv != 5 {
t.Errorf("expected dConv 5, got %d", cache.dConv)
}
}
func TestHybridCache_NumSeqs(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially no sequences
if cache.numSeqs() != 0 {
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
}
// Manually set up current batch state
cache.curSeqs = []int{1, 2, 3}
if cache.numSeqs() != 3 {
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
}
}
func TestHybridCache_SeqTokens(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially 0
if cache.seqTokens() != 0 {
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
}
// Manually set up current batch state
cache.curSeqTokens = 16
if cache.seqTokens() != 16 {
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
}
}
// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
cache := createSlotOnlyCache(4)
cache.curSeqs = []int{1, 2, 3}
seqs := cache.Seqs()
// Modify returned slice
seqs[0] = 999
// Original should be unchanged
if cache.curSeqs[0] != 1 {
t.Error("Seqs should return a clone, not the original slice")
}
}
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially not supported (no batch set up)
if cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be false initially")
}
// Set up a valid batch
cache.curSeqTokens = 1
cache.curSeqs = []int{1}
if !cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be true with valid batch")
}
}
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
cache := createSlotOnlyCache(4)
// zeroConvSlots should handle empty slots without panicking
cache.zeroConvSlots(nil, nil)
cache.zeroConvSlots(nil, []int{})
// zeroConvSlots should handle empty convStates without panicking
cache.zeroConvSlots(nil, []int{0, 1, 2})
}
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot (simulating sequence removal)
cache.refCount[slot1]--
cache.freeSlot(slot1)
delete(cache.slotForSeq, 1)
// Verify slot is in free list
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
}
// Allocate for new seq 2 - should get recycled slot
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
}
// This recycled slot would need zeroing in the real implementation
// The actual zeroing is tested via integration tests since it requires ML context
}
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
cache := createSlotOnlyCache(4)
// Simulate the slot allocation flow from StartForward
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
newSlots := []int{}
// Seq 1 doesn't have a slot - allocate and track
seq := 1
if _, ok := cache.slotForSeq[seq]; !ok {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
// Verify newSlots contains the allocated slot
if len(newSlots) != 1 {
t.Errorf("expected 1 new slot, got %d", len(newSlots))
}
// Seq 1 already has a slot - should NOT be tracked as new
newSlots2 := []int{}
if _, ok := cache.slotForSeq[seq]; !ok {
slot, _ := cache.allocSlot()
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots2 = append(newSlots2, slot)
}
// Verify no new slots for existing sequence
if len(newSlots2) != 0 {
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
}
}

View File

@@ -1,253 +0,0 @@
package lfm2
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeType string
originalContextLength int
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
}
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
if h > 0 {
return cmp.Or(o.headDim, o.hiddenSize/h)
}
}
return cmp.Or(o.headDim, o.hiddenSize)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
headCount := 1
for _, h := range o.numHeadsByLayer {
if h > 0 {
headCount = h
break
}
}
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
return nil, model.ErrUnsupportedModel
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "default":
// use default BPE pretokenizer
default:
// llama-bpe style (default for LFM2)
pretokenizers = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
hc, ok := c.(headCounts)
if !ok {
return nil, model.ErrUnsupportedModel
}
headCount := hc.HeadCount()
headCountKV := hc.HeadCountKV()
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
if m.numKVHeadsByLayer[i] == 0 {
m.Layers[i].Operator = &ShortConv{}
} else {
m.Layers[i].Operator = &Attention{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
dConv := max(0, lCache-1)
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
return &m, nil
}
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := opts.headDimValue()
numHeads := opts.numHeadsByLayer[layer]
numKVHeads := opts.numKVHeadsByLayer[layer]
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, numHeads, batchSize)
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("lfm2", New)
}

View File

@@ -1,50 +0,0 @@
package lfm2
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type shortConvKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
Conv *shortConvKernel `gguf:"shortconv.conv"`
InProj *nn.Linear `gguf:"shortconv.in_proj"`
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
nSeqs := cache.numSeqs()
seqTokens := cache.seqTokens()
hiddenSize := hiddenStates.Dim(0)
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
panic("lfm2: unsupported batch layout for shortconv")
}
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
elementSize := bcx.Stride(0)
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
state, err := cache.ConvState(ctx, layer)
if err != nil {
panic("lfm2: failed to get conv state: " + err.Error())
}
sx := state.Concat(ctx, bx, 0)
convOut := sx.SSMConv(ctx, sc.Conv.Weight)
y := c.Mul(ctx, convOut)
dConv := sx.Dim(0) - seqTokens
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}

View File

@@ -7,9 +7,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

View File

@@ -1,410 +0,0 @@
package parsers
import (
"context"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type glm46ParserState int
const (
glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
glm46ParserState_ThinkingStartedEatingWhitespace
glm46ParserState_CollectingThinking
glm46ParserState_ThinkingDoneEatingWhitespace
glm46ParserState_CollectingContent
glm46ParserState_ToolStartedEatingWhitespace
glm46ParserState_CollectingToolContent
)
const (
glm46ThinkingOpenTag = "<think>"
glm46ThinkingCloseTag = "</think>"
glm46ToolOpenTag = "<tool_call>"
glm46ToolCloseTag = "</tool_call>"
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}
type glm46Event interface {
isGLM46Event()
}
type glm46EventContent struct {
content string
}
func (glm46EventContent) isGLM46Event() {}
type glm46EventRawToolCall struct {
raw string
}
func (glm46EventRawToolCall) isGLM46Event() {}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseGLM46ToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46ParserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = glm46ParserState_ThinkingStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = glm46ParserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case glm46ParserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)
case glm46ParserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, glm46ThinkingCloseTag) {
thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, glm46EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = glm46ParserState_ThinkingDoneEatingWhitespace
} else {
p.state = glm46ParserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)
case glm46ParserState_CollectingContent:
if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
if after == "" {
p.state = glm46ParserState_ToolStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)
case glm46ParserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, glm46ToolCloseTag) {
toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm46 tool call closing tag found but no content before it")
}
events = append(events, glm46EventRawToolCall{raw: toolContent})
p.state = glm46ParserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
// We need to escape text between tags, but not the tags themselves
escaped := escapeGLM46Content(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}

View File

@@ -1,862 +0,0 @@
package parsers
import (
"encoding/xml"
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM46ParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []glm46Event
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "leading whitespace before think tag",
steps: []step{
{
input: " \n\t ",
wantEvents: []glm46Event{},
},
{
input: "<think>thinking</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
},
},
},
{
desc: "think tag with whitespace inside",
steps: []step{
{
input: "<think> \n thinking content \n </think>regular content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
glm46EventContent{content: "regular content"},
},
},
},
},
{
desc: "tool call with leading whitespace after opening tag",
steps: []step{
{
input: "<think></think><tool_call> \n test \n </tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "simple thinking then content",
steps: []step{
{
input: "<think>I am thinking</think>Now I respond",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "I am thinking"},
glm46EventContent{content: "Now I respond"},
},
},
},
},
{
desc: "streamed thinking content",
steps: []step{
{
input: "<think>hello",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
},
{
input: " world",
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "<think>Let me call a tool</think>here is text<tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "Let me call a tool"},
glm46EventContent{content: "here is text"},
},
},
{
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
},
},
},
},
{
desc: "tool call with content after",
steps: []step{
{
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after tool"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call is trimmed",
steps: []step{
{
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content is trimmed",
steps: []step{
{
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking close tag",
steps: []step{
{
input: "<think>thinking content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
},
{
input: "ink>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking open tag",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nk>content</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
},
},
{
desc: "split tool open tag",
steps: []step{
{
input: "<think>think</think>content<tool",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
},
{
input: "_call>inside",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "inside"},
},
},
},
},
{
desc: "partial thinking close tag fakeout",
steps: []step{
{
input: "<think>content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
{
input: "ought more",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
},
},
},
{
desc: "partial thinking open tag fakeout",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nking is fun",
wantEvents: []glm46Event{
glm46EventContent{content: " <thinking is fun"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "<think></think>content\n<tool",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
{
input: " fakeout",
wantEvents: []glm46Event{
glm46EventContent{content: "\n<tool fakeout"},
},
},
},
},
{
desc: "partial tool close tag fakeout",
steps: []step{
{
input: "<think></think><tool_call>content</tool",
wantEvents: []glm46Event{},
},
{
input: " fakeout",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "content</tool fakeout"},
},
},
},
},
{
desc: "empty thinking tag",
steps: []step{
{
input: "<think></think>content here",
wantEvents: []glm46Event{
glm46EventContent{content: "content here"},
},
},
},
},
{
desc: "multiple tool calls in sequence",
steps: []step{
{
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "first"},
glm46EventContent{content: "between"},
glm46EventRawToolCall{raw: "second"},
glm46EventContent{content: "end"},
},
},
},
},
{
desc: "no thinking tag - direct to content",
steps: []step{
{
input: "just content here",
wantEvents: []glm46Event{
glm46EventContent{content: "just content here"},
},
},
},
},
{
desc: "no thinking tag - skip to content then tool call",
steps: []step{
{
input: "Here's the answer:<tool_call>test</tool_call>done",
wantEvents: []glm46Event{
glm46EventContent{content: "Here's the answer:"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "done"},
},
},
},
},
{
desc: "no thinking tag - whitespace preserved when no tags",
steps: []step{
{
input: " \n content with leading whitespace",
wantEvents: []glm46Event{
glm46EventContent{content: " \n content with leading whitespace"},
},
},
},
},
{
desc: "whitespace after think close tag gets eaten",
steps: []step{
{
input: "<think>thinking</think> \n\t content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "whitespace after tool_call close tag gets eaten",
steps: []step{
{
input: "<think></think><tool_call>test</tool_call> \n\t content",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace (single chunk)",
steps: []step{
{
input: "<think>thinking content ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
},
},
{
input: "</think>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace with newlines",
steps: []step{
{
input: "<think>thinking\n\n ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content trailing whitespace emitted when more content arrives",
steps: []step{
{
input: "<think>thinking ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "more thinking",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: " more thinking"},
},
},
{
input: "</think>",
wantEvents: []glm46Event{},
},
},
},
{
desc: "thinking content withholds trailing whitespace before partial close tag",
steps: []step{
{
input: "<think>thinking </th",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "ink>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := GLM46Parser{}
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
testCases := []struct {
name string
xml string
wantKeys []string
wantValues []string
}{
{
name: "alternating keys and values",
xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
wantKeys: []string{"first", "second", "third"},
wantValues: []string{"A", "B", "C"},
},
{
name: "all keys then all values",
xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
wantKeys: []string{"key1", "key2", "key3"},
wantValues: []string{"val1", "val2", "val3"},
},
{
name: "mixed grouping",
xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
wantKeys: []string{"a", "b", "c"},
wantValues: []string{"1", "2", "3"},
},
{
name: "reverse order - all values then all keys",
xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
wantKeys: []string{"x", "y", "z"},
wantValues: []string{"X", "Y", "Z"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var parsed GLMToolCallXML
err := xml.Unmarshal([]byte(tc.xml), &parsed)
if err != nil {
t.Fatalf("failed to unmarshal XML: %v", err)
}
if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
}
if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
}
})
}
}
func TestGLM46ToolCallParsing(t *testing.T) {
type testCase struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
cases := []testCase{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
},
},
},
{
name: "function name with whitespace",
tools: []api.Tool{},
rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris"}`),
},
},
},
{
name: "values with special characters",
tools: []api.Tool{},
rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute-command",
Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
},
},
},
{
name: "unicode in function names and values",
tools: []api.Tool{},
rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
},
},
},
{
name: "empty value",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param1": ""}`),
},
},
},
{
name: "special chars in arg_key names",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
},
},
},
{
name: "multiple consecutive ampersands",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "test &&&& more"}`),
},
},
},
{
name: "mixed special chars together",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "<>&<>&"}`),
},
},
},
{
name: "newlines and tabs in parameter values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
indented line2
line3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
},
},
},
{
name: "single and double quotes in values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
},
},
},
{
name: "CDATA-like content that should be treated as text",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
},
},
},
{
name: "all special XML entities",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value>&lt;&gt;&amp;&apos;&quot;</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"entities": "&lt;&gt;&amp;&apos;&quot;"}`),
},
},
},
{
name: "order preservation with multiple parameters",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
},
},
},
{
name: "order preservation with identical key names but different positions",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
// Later occurrence should overwrite earlier one
Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
},
},
},
{
name: "array with mixed types",
tools: []api.Tool{
tool("process", map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: args(`{"items": [1, "hello", true, null]}`),
},
},
},
{
name: "empty array",
tools: []api.Tool{
tool("test", map[string]api.ToolProperty{
"tags": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test",
Arguments: args(`{"tags": []}`),
},
},
},
{
name: "anyOf array or string - with array of objects",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
}),
},
// <tool_call>TodoWrite
// <arg_key>todos</arg_key>
// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
// </tool_call>
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
},
},
},
{
name: "anyOf array or string - with plain string",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {Type: api.PropertyType{"array", "string"}},
}),
},
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": "Error: could not load todos"}`),
},
},
},
}
for i, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
if err != nil {
t.Errorf("case %d (%s): %v", i, tc.name, err)
}
if !toolCallEqual(gotToolCall, tc.wantToolCall) {
t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
}
})
}
}

View File

@@ -1,20 +0,0 @@
package parsers
import "github.com/ollama/ollama/api"
// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
GLM46Parser
}
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = glm46ParserState_CollectingThinking
}
return tools
}

View File

@@ -1,99 +0,0 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM47ParserAdd(t *testing.T) {
parser := GLM47Parser{}
parser.Init([]api.Tool{
tool("calculate", map[string]api.ToolProperty{
"count": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
}),
}, nil, nil)
// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
// so the model output does NOT include the opening <think> tag.
content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "plan" {
t.Fatalf("expected thinking 'plan', got %q", thinking)
}
if content != "Answer" {
t.Fatalf("expected content 'Answer', got %q", content)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
expectedArgs := args(`{"count": 3, "enabled": true}`)
if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
}
}
func TestGLM47ParserNoThinkingContent(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
// When thinking is enabled but model has no thinking to output,
// it should output </think> immediately followed by content.
content, thinking, calls, err := parser.Add("</think>Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserThinkingDisabled(t *testing.T) {
parser := GLM47Parser{}
// When thinking is disabled, parser stays in LookingForThinkingOpen state
parser.Init(nil, nil, &api.ThinkValue{Value: false})
// Model outputs plain content (prompt ended with </think>)
content, thinking, calls, err := parser.Add("Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserToolCallEscaping(t *testing.T) {
toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
expected := api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: args(`{"expr": "a < b && c > d"}`),
},
}
if !reflect.DeepEqual(toolCall, expected) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}

View File

@@ -1,498 +0,0 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type LFM2ParserState int
const (
LFM2CollectingThinking LFM2ParserState = iota
LFM2CollectingContent
LFM2CollectingToolCalls
)
const (
lfm2ThinkingOpenTag = "<think>"
lfm2ThinkingCloseTag = "</think>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
)
type LFM2Parser struct {
state LFM2ParserState
buffer strings.Builder
hasThinkingSupport bool
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
}
func (p *LFM2Parser) HasToolSupport() bool {
return true
}
func (p *LFM2Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !thinkingEnabled {
p.state = LFM2CollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = LFM2CollectingContent
return
}
p.state = LFM2CollectingThinking
p.needsThinkingLeadingTrim = true
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, thinkValue)
return tools
}
type lfm2Event interface {
isLFM2Event()
}
type lfm2EventThinkingContent struct {
content string
}
type lfm2EventContent struct {
content string
}
type lfm2EventToolCall struct {
toolCall api.ToolCall
}
func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event() {}
func (lfm2EventToolCall) isLFM2Event() {}
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case lfm2EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case lfm2EventThinkingContent:
thinkingSb.WriteString(event.content)
case lfm2EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
keepLooping := true
for keepLooping {
var events []lfm2Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
var events []lfm2Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case LFM2CollectingThinking:
// Strip opening <think> tag if present
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
p.needsThinkingLeadingTrim = true
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Trim leading whitespace after <think> tag (may span multiple chunks)
if p.needsThinkingLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once we have non-whitespace content or buffer is empty
if len(bufStr) > 0 {
p.needsThinkingLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingContent
p.needsThinkingLeadingTrim = false
// Set flag to trim any additional whitespace that may arrive in later chunks
p.needsContentLeadingTrim = len(remaining) == 0
if len(thinking) > 0 {
events = append(events, lfm2EventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
}
case LFM2CollectingContent:
// Trim leading whitespace after </think> tag (may span multiple chunks)
if p.needsContentLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once we have non-whitespace content
if len(bufStr) > 0 {
p.needsContentLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, lfm2EventContent{content: contentBefore})
}
return events, true
} else { // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, lfm2EventContent{content: bufStr})
}
return events, false
}
case LFM2CollectingToolCalls:
// Look for complete tool call JSON between tags
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
toolCallContent := bufStr[:idx]
if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
// Check if there's another tool call
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
remaining = remaining[len(lfm2ToolCallStartTag):]
} else {
// No more tool calls, go back to content
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.state = LFM2CollectingContent
}
p.buffer.Reset()
p.buffer.WriteString(remaining)
for _, tc := range toolCalls {
events = append(events, lfm2EventToolCall{toolCall: tc})
}
return events, true
} else if err != nil {
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
}
}
return events, false
}
return events, false
}
// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return nil, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return []api.ToolCall{{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}}, nil
}
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCalls(content)
}
// parsePythonStyleToolCalls parses one or more Python-style tool calls
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Strip outer brackets if present: [func(...)] -> func(...)
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
content = content[1 : len(content)-1]
}
var toolCalls []api.ToolCall
// Parse multiple function calls separated by commas at the top level
for len(content) > 0 {
content = strings.TrimSpace(content)
if content == "" {
break
}
// Skip leading comma from previous iteration
if strings.HasPrefix(content, ",") {
content = strings.TrimSpace(content[1:])
if content == "" {
break
}
}
// Find function name
parenIdx := strings.Index(content, "(")
if parenIdx == -1 {
return nil, errors.New("invalid tool call: no opening parenthesis")
}
funcName := strings.TrimSpace(content[:parenIdx])
if funcName == "" {
return nil, errors.New("invalid tool call: empty function name")
}
// Find matching closing parenthesis
closeIdx := findMatchingParen(content, parenIdx)
if closeIdx == -1 {
return nil, errors.New("invalid tool call: no matching closing parenthesis")
}
argsStr := content[parenIdx+1 : closeIdx]
args := api.NewToolCallFunctionArguments()
if argsStr != "" {
if err := parsePythonArgs(argsStr, &args); err != nil {
return nil, err
}
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
})
// Move past this function call
content = content[closeIdx+1:]
}
if len(toolCalls) == 0 {
return nil, errors.New("no tool calls found")
}
return toolCalls, nil
}
// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx
// Returns -1 if not found. Handles nested parentheses and quoted strings.
func findMatchingParen(s string, openIdx int) int {
depth := 1
i := openIdx + 1
for i < len(s) && depth > 0 {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 {
return i
}
case '\'', '"':
// Skip quoted string
quote := s[i]
i++
for i < len(s) && s[i] != quote {
if s[i] == '\\' && i+1 < len(s) {
i++ // skip escaped char
}
i++
}
}
i++
}
return -1
}
// parseToolCallContent parses a single tool call (for backward compatibility with tests)
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
calls, err := p.parseToolCallsContent(content)
if err != nil {
return api.ToolCall{}, err
}
if len(calls) == 0 {
return api.ToolCall{}, errors.New("no tool call found")
}
return calls[0], nil
}
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
}
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
}
return nil
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -68,12 +68,6 @@ func ParserForName(name string) Parser {
return &Nemotron3NanoParser{}
case "functiongemma":
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}

View File

@@ -96,11 +96,3 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}

View File

@@ -1,110 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GLM46Renderer struct{}
func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>\n")
}
return sb.String(), nil
}

View File

@@ -1,223 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
skip string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip != "" {
t.Skip(tt.skip)
}
renderer := &GLM46Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

View File

@@ -1,170 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type GLM47Renderer struct{}
func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
func formatGLM47ToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -1,191 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM47Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
},
{
name: "thinking disabled",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "[gMASK]<sop><|user|>Hello<|assistant|></think>",
},
{
name: "system and user",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello there"},
{Role: "user", Content: "How are you?"},
},
expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
},
{
name: "assistant with reasoning_content",
messages: []api.Message{
{Role: "user", Content: "Answer with reasoning."},
{Role: "assistant", Thinking: "Plan.", Content: "Done."},
},
expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
},
{
name: "tool call with empty content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
},
{
name: "tool call with content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Let me check",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "assistant", Content: "It is 22C."},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
},
{
name: "multiple tool calls and responses",
messages: []api.Message{
{Role: "user", Content: "Compare weather"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Paris"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "tool", Content: `{"temperature":18}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
},
{
name: "preserved thinking in multi-turn",
messages: []api.Message{
{Role: "user", Content: "Think step by step"},
{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
{Role: "user", Content: "Continue"},
},
expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &GLM47Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

View File

@@ -1,144 +0,0 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type LFM2Renderer struct {
IsThinking bool
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
startIdx = 1
}
// Append tools to first system content
if len(tools) > 0 {
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
}
// Output first system block if it has content
if firstSystemContent != "" {
sb.WriteString("<|im_start|>system\n")
sb.WriteString(firstSystemContent)
sb.WriteString("<|im_end|>\n")
}
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
if messages[i].Role == "assistant" {
lastAssistantIndex = i
break
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
} else {
sb.WriteString(content)
}
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil
}

View File

@@ -1,427 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "assistant with content and tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -80,12 +80,6 @@ func rendererForName(name string) Renderer {
return &Nemotron3NanoRenderer{}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
default:
return nil
}

View File

@@ -1,26 +1,6 @@
package renderers
import (
"encoding/json"
"github.com/ollama/ollama/api"
)
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}
func propsMap(s string) *api.ToolPropertiesMap {
var result api.ToolPropertiesMap
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in propsMap(): " + err.Error())
}
return &result
}
import "github.com/ollama/ollama/api"
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {

View File

@@ -3,22 +3,20 @@ package parser
import (
"bufio"
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"maps"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/mod/semver"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
@@ -54,7 +52,10 @@ var deprecatedParameters = []string{
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
req := &api.CreateRequest{}
req := &api.CreateRequest{
Files: make(map[string]string),
Adapters: make(map[string]string),
}
var messages []api.Message
var licenses []string
@@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
files, err := filesMap(c.Args, relativeDir)
if errors.Is(err, os.ErrNotExist) {
req.From = c.Args
continue
@@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return nil, err
}
if req.Files == nil {
req.Files = digestMap
} else {
for k, v := range digestMap {
req.Files[k] = v
}
}
maps.Copy(req.Files, files)
case "adapter":
path, err := expandPath(c.Args, relativeDir)
files, err := filesMap(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if err != nil {
return nil, err
}
req.Adapters = digestMap
maps.Copy(req.Adapters, files)
case "template":
req.Template = c.Args
case "system":
@@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return req, nil
}
func fileDigestMap(path string) (map[string]string, error) {
fl := make(map[string]string)
func filesMap(args, base string) (map[string]string, error) {
path, err := expandPath(args, base)
if err != nil {
return nil, err
}
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
var files []string
if fi.IsDir() {
fs, err := filesForModel(path)
if err != nil {
return nil, err
}
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else {
files = []string{path}
mapping := make(map[string]string)
if !fi.IsDir() {
return map[string]string{
filepath.Base(path): "abs:" + path,
}, nil
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
root, err := os.OpenRoot(path)
if err != nil {
return nil, err
}
defer root.Close()
if err := g.Wait(); err != nil {
files, err := filesForModel(root)
if err != nil {
return nil, err
}
return fl, nil
for _, file := range files {
// create a temporary mapping from relative path to absolute path
mapping[file] = "abs:" + filepath.Join(root.Name(), file)
}
return mapping, nil
}
func digestForFile(filename string) (string, error) {
filepath, err := filepath.EvalSymlinks(filename)
if err != nil {
return "", err
}
bin, err := os.Open(filepath)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
func filesForModel(root *os.Root) ([]string, error) {
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
f, err := root.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
bts := make([]byte, 512)
n, err := io.ReadFull(f, bts)
if errors.Is(err, io.ErrUnexpectedEOF) {
// short read, use what we have
bts = bts[:n]
} else if err != nil {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
matches, err := fs.Glob(root.FS(), pattern)
if err != nil {
return nil, err
}
@@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", ct, match, contentType)
}
}
@@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) {
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
if st, _ := glob("model*.safetensors", ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
} else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 {
// covers consolidated.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
} else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf
files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
} else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin
files = append(files, gg...)
} else {
@@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) {
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
js, err := glob("*.json", "text/plain")
if err != nil {
return nil, err
}
@@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
js, err = glob("**/*.json", "text/plain")
if err != nil {
return nil, err
}
@@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) {
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
// tokenizer.model might be a unresolved git lfs reference; error if it is
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}

View File

@@ -2,7 +2,6 @@ package parser
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
@@ -15,6 +14,7 @@ import (
"unicode/utf16"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
@@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you?
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
t.Helper()
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
@@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
if err := ggml.WriteGGUF(f, base, ti); err != nil {
t.Fatal(err)
}
// Calculate sha256 of file
if _, err := f.Seek(0, 0); err != nil {
t.Fatal(err)
}
digest, _ := getSHA256Digest(t, f)
return f.Name(), digest
return f.Name()
}
func TestCreateRequestFiles(t *testing.T) {
n1, d1 := createBinFile(t, nil, nil)
n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
n1 := createBinFile(t, nil, nil)
n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
cases := []struct {
input string
@@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) {
}{
{
fmt.Sprintf("FROM %s", n1),
&api.CreateRequest{Files: map[string]string{n1: d1}},
&api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
},
},
},
{
fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
&api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
&api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
filepath.Base(n2): "abs:" + n2,
},
},
},
}
@@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
@@ -860,15 +850,15 @@ func TestCreateRequestFiles(t *testing.T) {
func TestFilesForModel(t *testing.T) {
tests := []struct {
name string
setup func(string) error
wantFiles []string
wantErr bool
expectErrType error
name string
setup func(*testing.T, *os.Root)
want []string
wantErr error
}{
{
name: "safetensors model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
t.Helper()
files := []string{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
@@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"config.json",
@@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "safetensors with both tokenizer.json and tokenizer.model",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create binary content for tokenizer.model (application/octet-stream)
binaryContent := make([]byte, 512)
for i := range binaryContent {
@@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
// Write tokenizer.model as binary
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
return err
if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil {
t.Fatal(err)
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
@@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "safetensors with consolidated files - prefers model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
files := []string{
"model-00001-of-00001.safetensors",
"consolidated.safetensors",
"config.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00001.safetensors", // consolidated files should be excluded
"config.json",
},
},
{
name: "safetensors without model-.safetensors files - uses consolidated",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
files := []string{
"consolidated.safetensors",
"config.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"consolidated.safetensors",
"config.json",
},
},
{
name: "pytorch model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create a file that will be detected as application/zip
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
files := []string{
@@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin",
"config.json",
@@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "consolidated pth files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
files := []string{
"consolidated.00.pth",
@@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"consolidated.00.pth",
"consolidated.01.pth",
"config.json",
@@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "gguf files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create binary content that will be detected as application/octet-stream
binaryContent := make([]byte, 512)
for i := range binaryContent {
@@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model.gguf",
"config.json",
},
},
{
name: "bin files as gguf",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
binaryContent := make([]byte, 512)
for i := range binaryContent {
binaryContent[i] = byte(i % 256)
@@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model.bin",
"config.json",
},
},
{
name: "no model files found",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Only create non-model files
files := []string{"README.md", "config.json"}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantErr: true,
expectErrType: ErrModelNotFound,
wantErr: ErrModelNotFound,
},
{
name: "invalid content type for pytorch model",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create pytorch model file with wrong content type (text instead of zip)
files := []string{
"pytorch_model.bin",
@@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) {
}
for _, file := range files {
content := []byte("plain text content")
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantErr: true,
wantErr: ErrModelNotFound,
},
}
tmpDir := t.TempDir()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testDir := filepath.Join(tmpDir, tt.name)
if err := os.MkdirAll(testDir, 0o755); err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}
if err := tt.setup(testDir); err != nil {
t.Fatalf("Setup failed: %v", err)
}
files, err := filesForModel(testDir)
if tt.wantErr {
if err == nil {
t.Error("Expected error, but got none")
}
if tt.expectErrType != nil && err != tt.expectErrType {
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
}
return
}
root, err := os.OpenRoot(t.TempDir())
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
t.Fatalf("Failed to open root: %v", err)
}
defer root.Close()
tt.setup(t, root)
files, err := filesForModel(root)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("want %v error, got %v", tt.wantErr, err)
}
var relativeFiles []string
for _, file := range files {
rel, err := filepath.Rel(testDir, file)
if err != nil {
t.Fatalf("Failed to get relative path: %v", err)
}
relativeFiles = append(relativeFiles, rel)
}
if len(relativeFiles) != len(tt.wantFiles) {
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
}
fileSet := make(map[string]bool)
for _, file := range relativeFiles {
fileSet[file] = true
}
for _, wantFile := range tt.wantFiles {
if !fileSet[wantFile] {
t.Errorf("Missing expected file: %s", wantFile)
}
if diff := cmp.Diff(tt.want, files); diff != "" {
t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff)
}
})
}

View File

@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
name := t.Name
quantize := strings.HasSuffix(name, "weight")
// don't quantize vision encoder tensors (named with "v." prefix)
quantize = quantize && !strings.HasPrefix(name, "v.")
// don't quantize vision stuff
quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
@@ -219,9 +219,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// NOTE: can't use LLM_TN here because the layer number is not known
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
// do not quantize LFM2's shortconv kernel weights
quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
// do not quantize RWKV's time_mix_first tensors
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")

View File

@@ -220,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
@@ -315,7 +321,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// expire the runner if unload is requested (empty prompt, keep alive is 0)
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -329,12 +335,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
@@ -1149,9 +1149,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
ModelInfo: make(map[string]any),
}
if m.Config.RemoteHost != "" {

View File

@@ -2101,95 +2101,3 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
})
}
func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)
var loadFnCalled bool
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
loadFnCalled = false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "",
KeepAlive: &api.Duration{Duration: 0},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.DoneReason != "unload" {
t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
}
if !resp.Done {
t.Error("expected done to be true")
}
if loadFnCalled {
t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
}
})
}

View File

@@ -1,38 +0,0 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a simple, generic thread-safe map implementation.
type SyncMap[K comparable, V any] struct {
mu sync.RWMutex
m map[K]V
}
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
m: make(map[K]V),
}
}
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
val, ok := s.m[key]
return val, ok
}
func (s *SyncMap[K, V]) Store(key K, value V) {
s.mu.Lock()
defer s.mu.Unlock()
s.m[key] = value
}
func (s *SyncMap[K, V]) Items() map[K]V {
s.mu.RLock()
defer s.mu.RUnlock()
// shallow copy map items
return maps.Clone(s.m)
}

View File

@@ -6,9 +6,8 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/ollama/ollama/envconfig"
)
// ManifestLayer represents a layer in the manifest.
@@ -33,15 +32,31 @@ type ModelManifest struct {
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
switch runtime.GOOS {
case "darwin":
return filepath.Join(home, ".ollama", "models", "blobs")
case "linux":
return filepath.Join(home, ".ollama", "models", "blobs")
case "windows":
return filepath.Join(home, ".ollama", "models", "blobs")
default:
return filepath.Join(home, ".ollama", "models", "blobs")
}
}
// DefaultManifestDir returns the manifest storage directory.
// Respects OLLAMA_MODELS.
// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
}
// LoadManifest loads a manifest for the given model name.

View File

@@ -1,26 +0,0 @@
package imagegen
import (
"path/filepath"
"testing"
)
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
modelsDir := filepath.Join(t.TempDir(), "models")
// Simulate packaged/systemd environment
t.Setenv("OLLAMA_MODELS", modelsDir)
t.Setenv("HOME", "/usr/share/ollama")
// Manifest dir must respect OLLAMA_MODELS
wantManifest := filepath.Join(modelsDir, "manifests")
if got := DefaultManifestDir(); got != wantManifest {
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
}
// Blob dir must respect OLLAMA_MODELS
wantBlobs := filepath.Join(modelsDir, "blobs")
if got := DefaultBlobDir(); got != wantBlobs {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
}
}