Compare commits

...

1 Commits

Author SHA1 Message Date
jmorganca
f7c5f40daa model: add lfm2 2026-01-20 01:56:45 -08:00
15 changed files with 2313 additions and 0 deletions

View File

@@ -313,6 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

115
convert/convert_lfm2.go Normal file
View File

@@ -0,0 +1,115 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer head count arrays based on layer_types
headCounts := make([]uint32, p.NumHiddenLayers)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := uint32(0); i < p.NumHiddenLayers; i++ {
if i < uint32(len(p.LayerTypes)) && p.LayerTypes[i] == "full_attention" {
headCounts[i] = p.NumAttentionHeads
kvHeadCounts[i] = p.NumKeyValueHeads
} else {
// Conv layers have 0 head counts
headCounts[i] = 0
kvHeadCounts[i] = 0
}
}
kv["lfm2.attention.head_count"] = headCounts
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
// Renderer and parser config for thinking model
kv["tokenizer.chat_template.renderer"] = "lfm2-thinking"
kv["tokenizer.chat_template.parser"] = "lfm2-thinking"
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights from [D, 1, K] to [D, K]
// The shortconv.conv.weight tensors have shape [D, 1, K] in HuggingFace
// We remove the middle dimension (which is always 1) to get [D, K]
// Note: ollama's GGUF writer does NOT reverse shapes like llama.cpp's Python writer does,
// so we keep the shape in the same order as the original safetensors file
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") && len(shape) == 3 && shape[1] == 1 {
// Squeeze: [D, 1, K] -> [D, K]
shape = []uint64{shape[0], shape[2]}
// No repacker needed - data layout is already correct since the middle dim is 1
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
return data, nil
})
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}

View File

@@ -270,6 +270,7 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}

View File

@@ -162,6 +162,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

358
model/models/lfm2/cache.go Normal file
View File

@@ -0,0 +1,358 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
}
c.curSlots = append(c.curSlots, slot)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
ctx.Forward(buf.SetRows(ctx, src, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
ctx.Forward(buf.SetRows(ctx, src, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

284
model/models/lfm2/model.go Normal file
View File

@@ -0,0 +1,284 @@
package lfm2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeType string
originalContextLength int
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
}
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
if h > 0 {
return cmp.Or(o.headDim, o.hiddenSize/h)
}
}
return cmp.Or(o.headDim, o.hiddenSize)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
headCount := 1
for _, h := range o.numHeadsByLayer {
if h > 0 {
headCount = h
break
}
}
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
// TODO: support mixtures of experts
return nil, model.ErrUnsupportedModel
}
// Tokenizer
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
// LFM2 uses a llama3-style BPE pretokenizer.
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "lfm2", "llama3", "llama-v3", "llama-bpe":
pretokenizers = []string{
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "qwen2":
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "refact":
pretokenizers = []string{
`\p{N}`,
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
}
case "tekken":
pretokenizers = []string{
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "default":
// no-op use the default bpe pretokenizer
default:
// use a llama-style pretokenizer
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
return nil, fmt.Errorf("unsupported tokenizer: llama")
default:
return nil, model.ErrUnsupportedTokenizer
}
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
hc, ok := c.(headCounts)
if !ok {
return nil, model.ErrUnsupportedModel
}
headCount := hc.HeadCount()
headCountKV := hc.HeadCountKV()
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
if m.numKVHeadsByLayer[i] == 0 {
m.Layers[i].Operator = &ShortConv{}
} else {
m.Layers[i].Operator = &Attention{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
dConv := max(0, lCache-1)
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
return &m, nil
}
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := opts.headDimValue()
numHeads := opts.numHeadsByLayer[layer]
numKVHeads := opts.numKVHeadsByLayer[layer]
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, numHeads, batchSize)
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("lfm2", New)
}

View File

@@ -0,0 +1,52 @@
package lfm2
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type shortConvKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
Conv *shortConvKernel `gguf:"shortconv.conv"`
InProj *nn.Linear `gguf:"shortconv.in_proj"`
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
nSeqs := cache.numSeqs()
seqTokens := cache.seqTokens()
hiddenSize := hiddenStates.Dim(0)
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
panic("lfm2: unsupported batch layout for shortconv")
}
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
elementSize := bcx.Stride(0)
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
state, err := cache.ConvState(ctx, layer)
if err != nil {
panic("lfm2: failed to get conv state: " + err.Error())
}
sx := state.Concat(ctx, bx, 0)
// Cast weight to F32 for SSMConv (Metal requires F32)
weightF32 := sc.Conv.Weight.Cast(ctx, ml.DTypeF32)
convOut := sx.SSMConv(ctx, weightF32)
y := c.Mul(ctx, convOut)
dConv := sx.Dim(0) - seqTokens
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}

View File

@@ -9,6 +9,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

393
model/parsers/lfm2.go Normal file
View File

@@ -0,0 +1,393 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type LFM2ParserState int
const (
LFM2CollectingThinking LFM2ParserState = iota
LFM2CollectingContent
LFM2CollectingToolCalls
)
const (
lfm2ThinkingOpenTag = "<think>"
lfm2ThinkingCloseTag = "</think>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
)
type LFM2Parser struct {
state LFM2ParserState
buffer strings.Builder
hasThinkingSupport bool
}
func (p *LFM2Parser) HasToolSupport() bool {
return true
}
func (p *LFM2Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !thinkingEnabled {
p.state = LFM2CollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = LFM2CollectingContent
return
}
p.state = LFM2CollectingThinking
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, thinkValue)
return tools
}
type lfm2Event interface {
isLFM2Event()
}
type lfm2EventThinkingContent struct {
content string
}
type lfm2EventContent struct {
content string
}
type lfm2EventToolCall struct {
toolCall api.ToolCall
}
func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event() {}
func (lfm2EventToolCall) isLFM2Event() {}
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case lfm2EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case lfm2EventThinkingContent:
thinkingSb.WriteString(event.content)
case lfm2EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
keepLooping := true
for keepLooping {
var events []lfm2Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
var events []lfm2Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case LFM2CollectingThinking:
// Strip opening <think> tag if present
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingContent
if len(thinking) > 0 {
events = append(events, lfm2EventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
}
case LFM2CollectingContent:
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, lfm2EventContent{content: contentBefore})
}
return events, true
} else { // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, lfm2EventContent{content: bufStr})
}
return events, false
}
case LFM2CollectingToolCalls:
// Look for complete tool call JSON between tags
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
toolCallContent := bufStr[:idx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
// Check if there's another tool call
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
remaining = remaining[len(lfm2ToolCallStartTag):]
} else {
// No more tool calls, go back to content
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.state = LFM2CollectingContent
}
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, lfm2EventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
}
}
return events, false
}
return events, false
}
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return api.ToolCall{}, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}, nil
}
// Try Python-style format: [func(arg1='val1', arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCall(content)
}
// parsePythonStyleToolCall parses tool calls in Python function call syntax
// Examples: [bash(command='ls')] or bash(command='ls', flag='-la')
func (p *LFM2Parser) parsePythonStyleToolCall(content string) (api.ToolCall, error) {
content = strings.TrimSpace(content)
// Strip outer brackets if present: [func(...)] -> func(...)
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
content = content[1 : len(content)-1]
}
// Find function name and arguments: func(args)
parenIdx := strings.Index(content, "(")
if parenIdx == -1 {
return api.ToolCall{}, errors.New("invalid tool call: no opening parenthesis")
}
funcName := strings.TrimSpace(content[:parenIdx])
if funcName == "" {
return api.ToolCall{}, errors.New("invalid tool call: empty function name")
}
// Extract arguments between parentheses
if !strings.HasSuffix(content, ")") {
return api.ToolCall{}, errors.New("invalid tool call: no closing parenthesis")
}
argsStr := content[parenIdx+1 : len(content)-1]
args := api.NewToolCallFunctionArguments()
if argsStr != "" {
// Parse key='value' or key="value" pairs
if err := parsePythonArgs(argsStr, &args); err != nil {
return api.ToolCall{}, err
}
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
}, nil
}
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
}
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
}
return nil
}

549
model/parsers/lfm2_test.go Normal file
View File

@@ -0,0 +1,549 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "simple_content",
input: "Hello, how are you?",
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "thinking_content",
input: "I need to think about this...</think>The answer is 42.",
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "no_thinking_simple",
input: "Just a regular response.",
expectedContent: "Just a regular response.",
hasThinking: false,
},
{
name: "thinking_with_newlines",
input: "Let me think:\n- Point 1\n- Point 2</think>\n\nHere's my answer.",
expectedThinking: "Let me think:\n- Point 1\n- Point 2",
expectedContent: "Here's my answer.",
hasThinking: true,
},
{
name: "tool_call_simple",
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: false,
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
hasThinking: false,
},
{
name: "complex_tool_arguments",
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
"items": []interface{}{"item1", "item2"},
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
}),
},
},
},
hasThinking: false,
},
{
name: "thinking_with_tool_call",
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: true,
},
{
name: "empty_content",
input: "",
expectedContent: "",
hasThinking: false,
},
{
name: "only_thinking",
input: "Just thinking content</think>",
expectedThinking: "Just thinking content",
expectedContent: "",
hasThinking: true,
},
{
name: "unicode_content",
input: "مرحبا بالعالم! 你好世界! 🌍",
expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
hasThinking: false,
},
{
name: "emoji_passthrough",
input: "Task completed ✅ 🎉",
expectedContent: "Task completed ✅ 🎉",
hasThinking: false,
},
{
name: "newlines_and_whitespace",
input: "Line 1\n\nLine 3\t\tTabbed content",
expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
hasThinking: false,
},
{
name: "thinking_with_unicode",
input: "我在思考这个问题...</think>答案是42。",
expectedThinking: "我在思考这个问题...",
expectedContent: "答案是42。",
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "北京天气",
"language": "中文",
}),
},
},
},
hasThinking: false,
},
{
name: "thinking_with_special_chars",
input: "Let me calculate: 2+2=4 & 3*3=9...</think>The results are correct!",
expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
expectedContent: "The results are correct!",
hasThinking: true,
},
{
name: "empty_tool_call_args",
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, calls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "streaming_simple_content",
chunks: []string{"Hello, ", "how are ", "you?"},
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "streaming_thinking",
chunks: []string{"I need to ", "think about this", "...</think>", "The answer is 42."},
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
hasThinking: false,
},
{
name: "streaming_thinking_with_partial_tag",
chunks: []string{"Thinking about this", "...</", "think>", "Done thinking."},
expectedThinking: "Thinking about this...",
expectedContent: "Done thinking.",
hasThinking: true,
},
{
name: "streaming_unicode_content",
chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
expectedContent: "مرحبا بالعالم! 你好世界!",
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: testArgs(map[string]any{
"x": float64(42),
"y": float64(24),
}),
},
},
},
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
var allContent, allThinking string
var allCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, calls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
}
if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_HasThinkingSupport(t *testing.T) {
tests := []struct {
name string
hasThinking bool
expectedSupport bool
}{
{
name: "thinking_enabled",
hasThinking: true,
expectedSupport: true,
},
{
name: "thinking_disabled",
hasThinking: false,
expectedSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
}
})
}
}
func TestLFM2Parser_HasToolSupport(t *testing.T) {
parser := &LFM2Parser{}
if !parser.HasToolSupport() {
t.Error("HasToolSupport() should return true")
}
}
func TestLFM2Parser_Init(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: true}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test_tool",
},
},
}
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
}
// Test initial state is set to thinking when enabled
if parser.state != LFM2CollectingThinking {
t.Errorf("Expected initial state to be LFM2CollectingThinking, got %v", parser.state)
}
}
func TestLFM2Parser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call",
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
{
name: "complex_arguments",
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: testArgs(map[string]any{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
}),
},
},
},
{
name: "empty_arguments",
content: `{"name":"ping","arguments":{}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
{
name: "unicode_in_tool_name",
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: testArgs(map[string]any{
"城市": "北京",
}),
},
},
},
{
name: "numeric_arguments",
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: testArgs(map[string]any{
"x": 3.14,
"y": float64(42),
"enabled": true,
}),
},
},
},
{
name: "invalid_json",
content: `{invalid json}`,
expectError: true,
},
{
name: "missing_name",
content: `{"arguments":{"arg":"value"}}`,
expectError: true,
},
{
name: "empty_name",
content: `{"name":"","arguments":{"arg":"value"}}`,
expectError: true,
},
}
parser := &LFM2Parser{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestLFM2Parser_EdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
hasThinking bool
}{
{
name: "multiple_think_close_tags",
input: "First thought</think>Second thought</think>Final content",
expectedThinking: "First thought",
expectedContent: "Second thought</think>Final content",
hasThinking: true,
},
{
name: "empty_thinking_content",
input: "</think>Just content",
expectedThinking: "",
expectedContent: "Just content",
hasThinking: true,
},
{
name: "thinking_disabled_with_think_tags",
input: "Some content</think>More content",
expectedContent: "Some content</think>More content",
hasThinking: false,
},
{
name: "whitespace_only_content",
input: " \n\t ",
expectedContent: " \n\t ",
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, _, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -70,6 +70,10 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}

139
model/renderers/lfm2.go Normal file
View File

@@ -0,0 +1,139 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type LFM2Renderer struct {
IsThinking bool
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
startIdx = 1
}
// Append tools to first system content
if len(tools) > 0 {
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]\nOutput function calls as JSON."
}
// Output first system block if it has content
if firstSystemContent != "" {
sb.WriteString("<|im_start|>system\n")
sb.WriteString(firstSystemContent)
sb.WriteString("<|im_end|>\n")
}
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
if messages[i].Role == "assistant" {
lastAssistantIndex = i
break
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// For past assistant messages (not the last), strip thinking when keep_past_thinking=false
// This applies to both regular content and content before tool calls
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if !isLastAssistant && !keepPastThinking && strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
content = strings.TrimSpace(parts[1])
}
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
} else {
sb.WriteString(content)
}
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil
}

View File

@@ -0,0 +1,403 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "assistant with content and tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]` + "\n" + `Output function calls as JSON.<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]` + "\n" + `Output function calls as JSON.<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -82,6 +82,10 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
default:
return nil
}