Compare commits

...

13 Commits

Author SHA1 Message Date
nicole pardal
29a2d6d931 fixed converter 2025-12-10 16:11:52 -08:00
nicole pardal
b6f769ae60 rope options 2025-12-10 14:40:15 -08:00
ParthSareen
a613eca69c improvements 2025-12-09 22:26:30 -08:00
ParthSareen
3015146cda test + model 2025-12-09 18:28:29 -08:00
ParthSareen
5d50848c52 arch changes wip 2025-12-09 14:05:48 -08:00
ParthSareen
991a63b6ca renderers/parsers: olmo3 instruct 2025-12-09 10:52:57 -08:00
nicole pardal
2c147bc780 fixed pretokenizer 2025-12-09 10:52:37 -08:00
nicole pardal
d8bf6a5dee fixed generation issue 2025-12-09 10:52:37 -08:00
nicole pardal
3eea7f198b removed original olmo support 2025-12-09 10:52:37 -08:00
nicole pardal
494284770d removed olmo1 support 2025-12-09 10:52:37 -08:00
nicole pardal
57569274ec lint 2025-12-09 10:52:37 -08:00
nicole pardal
7505cd963e updated converter 2025-12-09 10:52:37 -08:00
nicole pardal
bdcf9e811b olmo model initial 2025-12-09 10:52:37 -08:00
12 changed files with 2389 additions and 0 deletions

View File

@@ -200,6 +200,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

124
convert/convert_olmo.go Normal file
View File

@@ -0,0 +1,124 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
ClampKQV float32 `json:"f_clamp_kqv"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo2"
kv["olmo2.block_count"] = p.NumHiddenLayers
kv["olmo2.context_length"] = p.MaxPositionEmbeddings
kv["olmo2.embedding_length"] = p.HiddenSize
kv["olmo2.feed_forward_length"] = p.IntermediateSize
kv["olmo2.attention.head_count"] = p.NumAttentionHeads
kv["olmo2.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo2.rope.freq_base"] = p.RopeTheta
} else {
kv["olmo2.rope.freq_base"] = float32(10000.0)
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo2.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo2.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo2.rope.scaling.type"] = p.RopeScaling.RopeType
}
}
if p.RMSNormEPS > 0 {
kv["olmo2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.ClampKQV > 0 {
kv["olmo2.attention.clamp_kqv"] = p.ClampKQV
}
if p.SlidingWindow > 0 {
kv["olmo2.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo2.attention.sliding_window_pattern"] = slidingPattern
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
"deepseekocr",
"deepseek2",
"nomic-bert",
"olmo2",
}, kv.Architecture())
}

View File

@@ -13,6 +13,7 @@ import (
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"

298
model/models/olmo/model.go Normal file
View File

@@ -0,0 +1,298 @@
package olmo
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
// headDim, ropeDim int
eps, ropeBase, ropeScale float32
originalContextLength int
attnFactor float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
slidingWindowPattern []bool
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
var pretokenizers []string
if c.String("tokenizer.ggml.pre") != "default" {
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
hiddenSize := int(c.Uint("embedding_length"))
numHeads := int(c.Uint("attention.head_count"))
numKVHeads := int(c.Uint("attention.head_count_kv"))
eps := c.Float("attention.layer_norm_rms_epsilon")
ropeBase := c.Float("rope.freq_base", 1e4)
ropeScale := c.Float("rope.scaling.factor", 1)
originalContextLength := int(c.Uint("rope.scaling.original_context_length"))
attnFactor := c.Float("rope.scaling.attn_factor", 1)
ropeType := c.String("rope.scaling.type")
ropeExtrapolation := c.Float("rope.scaling.extrapolation_factor", 1.0)
fmt.Printf("hiddenSize: %d\n", hiddenSize)
fmt.Printf("numHeads: %d\n", numHeads)
fmt.Printf("numKVHeads: %d\n", numKVHeads)
fmt.Printf("eps: %f\n", eps)
fmt.Printf("ropeBase: %f\n", ropeBase)
fmt.Printf("ropeScale: %f\n", ropeScale)
fmt.Printf("originalContextLength: %d\n", originalContextLength)
fmt.Printf("attnFactor: %f\n", attnFactor)
fmt.Printf("ropeType: %s\n", ropeType)
fmt.Printf("ropeExtrapolation: %f\n", ropeExtrapolation)
fmt.Printf("sliding_window_pattern: %v\n", c.Bools("attention.sliding_window_pattern"))
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
eps: eps,
ropeBase: ropeBase,
ropeScale: ropeScale,
originalContextLength: originalContextLength,
attnFactor: attnFactor,
ropeType: ropeType,
ropeExtrapolation: ropeExtrapolation,
ropeBetaFast: 32.0,
ropeBetaSlow: 1.0,
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
},
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim int, isSWA bool) ml.Tensor {
var ropeOpts []func(*rope.Options)
ropeOpts = append(ropeOpts, rope.WithTypeNeoX())
// Both SWA and non-SWA use beta_fast and beta_slow
// defaults
ropeOpts = append(ropeOpts,
rope.WithBetaFast(m.ropeBetaFast),
rope.WithBetaSlow(m.ropeBetaSlow),
)
// SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
// Non-SWA uses full yarn parameters
if m.originalContextLength > 0 {
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(m.originalContextLength),
)
// no yarn for swa
if isSWA {
ropeOpts = append(ropeOpts,
rope.WithExtrapolationFactor(0),
rope.WithAttentionFactor(1.),
)
} else {
ropeOpts = append(ropeOpts,
rope.WithExtrapolationFactor(m.ropeExtrapolation),
rope.WithAttentionFactor(m.attnFactor),
)
}
}
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / m.ropeScale
}
return nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, freqScale, ropeOpts...)
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := m.hiddenSize / m.numHeads
ropeDim := headDim
query := sa.Query.Forward(ctx, hiddenState)
// double check type
query = sa.QNorm.Forward(ctx, query, m.eps)
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
//check here
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
key := sa.Key.Forward(ctx, hiddenState)
key = sa.KNorm.Forward(ctx, key, m.eps)
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
key = m.applyRoPE(ctx, key, positions, ropeDim, isSWA)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
// check attention scaling as well
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := m.hiddenSize / m.numHeads
isSWA := m.isSWALayer(layer)
return m.applyRoPE(ctx, key, shift, ropeDim, isSWA), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLP *MLP
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
// return hiddenState
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// i think this should be after getting the rows?
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState.Add(ctx, residual)
}
// Olmo3 has Sliding Window Attention (SWA) 3 out of 4 layers.
func (m *Model) isSWALayer(layerIdx int) bool {
return m.Options.slidingWindowPattern[layerIdx]
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
cacheType := cacheTypeSWA
isSWA := m.isSWALayer(i)
if !isSWA {
cacheType = cacheTypeCausal
}
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
// would need to check the cache at the layer instead
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// TODO: not sure about the index here
causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{}})
}
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
// return hiddenState, nil
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("olmo2", New)
}

View File

@@ -0,0 +1,568 @@
package olmo
import (
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
typemodel "github.com/ollama/ollama/types/model"
)
var args struct {
model,
prompt string
layers int
}
func TestMain(m *testing.M) {
flag.StringVar(&args.model, "model", "", "path to model (e.g., olmo3:latest)")
flag.StringVar(&args.prompt, "prompt", "Hello, how are", "model prompt")
flag.IntVar(&args.layers, "layers", math.MaxInt, "num of gpu layers")
flag.Parse()
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
os.Exit(m.Run())
}
func blob(tb testing.TB, modelName string) string {
tb.Helper()
models := envconfig.Models()
manifest, err := os.Open(filepath.Join(models, "manifests", typemodel.ParseName(modelName).Filepath()))
if err != nil {
tb.Fatal(err)
}
defer manifest.Close()
var m struct {
Layers []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.NewDecoder(manifest).Decode(&m); err != nil {
tb.Fatal(err)
}
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.model" {
tb.Log("using model blob", layer.Digest)
return filepath.Join(models, "blobs", strings.ReplaceAll(layer.Digest, ":", "-"))
}
}
tb.Fatal("model blob not found")
return ""
}
func loadFloatsFromBinary(filename string) ([]float32, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
if fi.Size()%4 != 0 {
return nil, fmt.Errorf("file size %d not multiple of 4", fi.Size())
}
n := int(fi.Size() / 4)
floats := make([]float32, n)
if err := binary.Read(f, binary.LittleEndian, floats); err != nil {
return nil, err
}
return floats, nil
}
func TestTokenization(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
prompt := args.prompt
if prompt == "" {
prompt = "hello, how are you?"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
t.Logf("num tokens: %d", len(tokens))
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
}
func TestAttentionForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
t.Logf("Model options: hiddenSize=%d, numHeads=%d, numKVHeads=%d",
olmoModel.hiddenSize, olmoModel.numHeads, olmoModel.numKVHeads)
t.Logf("Layer 0 attention: %+v", olmoModel.Layers[0].SelfAttention)
ctx := m.Backend().NewContext()
// Create test hidden states: (hiddenSize, batchSize)
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0 // Simple test values
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
// Test attention forward (without cache for simplicity)
attentionBlock := olmoModel.Layers[0].SelfAttention
isSWA := olmoModel.isSWALayer(0)
t.Logf("Layer 0 isSWA: %v", isSWA)
result := attentionBlock.Forward(ctx, hiddenStates, positions, nil, olmoModel, isSWA)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Attention result shape: %v dtype: %v", result.Shape(), result.DType())
// Optionally dump to file
// if err := os.WriteFile("/tmp/olmo_attention_output.bin", result.Bytes(), 0644); err != nil {
// t.Fatal(err)
// }
}
func TestMLPForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
ctx := m.Backend().NewContext()
// Create test hidden states
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
mlpBlock := olmoModel.Layers[0].MLP
result := mlpBlock.Forward(ctx, hiddenStates, olmoModel)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("MLP result shape: %v dtype: %v", result.Shape(), result.DType())
// Parse result bytes to float32
resultBytes := result.Bytes()
resultFloats := make([]float32, len(resultBytes)/4)
for i := range resultFloats {
bits := binary.LittleEndian.Uint32(resultBytes[i*4 : (i+1)*4])
resultFloats[i] = math.Float32frombits(bits)
}
// Compute statistics
var minVal, maxVal, sum float32
minVal = resultFloats[0]
maxVal = resultFloats[0]
for _, v := range resultFloats {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
sum += v
}
mean := sum / float32(len(resultFloats))
// Build readable output
var sb strings.Builder
sb.WriteString("# MLP Forward Output\n\n")
sb.WriteString(fmt.Sprintf("# Input Shape: [%d, %d] (hiddenSize, batchSize)\n", hiddenSize, batchSize))
sb.WriteString(fmt.Sprintf("# Output Shape: %v\n", result.Shape()))
sb.WriteString(fmt.Sprintf("# DType: %v\n", result.DType()))
sb.WriteString(fmt.Sprintf("# Layer: 0\n\n"))
sb.WriteString("## Statistics\n\n")
sb.WriteString(fmt.Sprintf(" Total elements: %d\n", len(resultFloats)))
sb.WriteString(fmt.Sprintf(" Min: %v\n", minVal))
sb.WriteString(fmt.Sprintf(" Max: %v\n", maxVal))
sb.WriteString(fmt.Sprintf(" Mean: %v\n\n", mean))
sb.WriteString("## Input Hidden States (first 20 values)\n\n")
sb.WriteString(" [")
for i := 0; i < min(20, len(hsFloats)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", hsFloats[i]))
}
sb.WriteString("]\n\n")
sb.WriteString("## Output Values\n\n")
// Per-position output (each position in batch)
for pos := 0; pos < batchSize; pos++ {
sb.WriteString(fmt.Sprintf("Position %d (hiddenSize=%d values):\n", pos, hiddenSize))
// Extract values for this position
posStart := pos * hiddenSize
posEnd := posStart + hiddenSize
if posEnd > len(resultFloats) {
posEnd = len(resultFloats)
}
posValues := resultFloats[posStart:posEnd]
// Full tensor values
sb.WriteString(" [")
for i, v := range posValues {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", v))
}
sb.WriteString("]\n\n")
}
// Save to file
if err := os.WriteFile("/tmp/olmo_mlp_forward.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.txt")
// Also save binary
if err := os.WriteFile("/tmp/olmo_mlp_forward.bin", resultBytes, 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.bin")
// Print summary to console
fmt.Println(sb.String())
}
func TestFullForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
ctx := m.Backend().NewContext()
prompt := args.prompt
if prompt == "" {
prompt = "Hello, how are you?"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
seqLen := len(tokens)
inputsTensor := ctx.Input().FromInts(tokens, seqLen)
positions := make([]int32, seqLen)
sequences := make([]int, seqLen)
for i := range tokens {
positions[i] = int32(i)
sequences[i] = 0
}
// Output ALL positions
outputIndices := make([]int32, seqLen)
for i := range outputIndices {
outputIndices[i] = int32(i)
}
outputs := ctx.Input().FromInts(outputIndices, seqLen)
batch := input.Batch{
Inputs: inputsTensor,
Positions: positions,
Sequences: sequences,
Outputs: outputs,
}
// Initialize cache
if cache := m.Config().Cache; cache != nil {
cache.Init(m.Backend(), ml.DTypeF16, 1, 4096, seqLen)
}
result, err := model.Forward(ctx, m, batch)
if err != nil {
t.Fatal(err)
}
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Forward pass completed, result shape: %v", result.Shape())
// Dump logits to binary file
if err := os.WriteFile("/tmp/olmo_logits.bin", result.Bytes(), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.bin")
// Parse logits from bytes for detailed analysis
logitsBytes := result.Bytes()
vocabSize := result.Shape()[0]
// Read float32 values - shape is (vocab_size, seq_len)
allLogits := make([]float32, len(logitsBytes)/4)
for i := range allLogits {
bits := binary.LittleEndian.Uint32(logitsBytes[i*4 : (i+1)*4])
allLogits[i] = math.Float32frombits(bits)
}
// Create detailed text dump matching Python format
var sb strings.Builder
sb.WriteString("# Full Forward Logits\n\n")
sb.WriteString(fmt.Sprintf("# Shape: [1, %d, %d]\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Layout: (batch=1, seq_len=%d, vocab_size=%d)\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Prompt: '%s'\n", prompt))
sb.WriteString(fmt.Sprintf("# Tokens: %v\n\n", tokens))
type logitPair struct {
tokenID int
value float32
}
// Process each position
for pos := 0; pos < seqLen; pos++ {
// Extract logits for this position
// Shape is (vocab_size, seq_len), so logits[v*seqLen + pos] gives logit for vocab v at position pos
posLogits := make([]float32, vocabSize)
for v := 0; v < vocabSize; v++ {
posLogits[v] = allLogits[v*seqLen+pos]
}
// Find top 10 logits
pairs := make([]logitPair, len(posLogits))
for i, v := range posLogits {
pairs[i] = logitPair{tokenID: i, value: v}
}
// Sort by value descending (simple bubble sort for small top-k)
for i := 0; i < min(10, len(pairs)); i++ {
for j := i + 1; j < len(pairs); j++ {
if pairs[j].value > pairs[i].value {
pairs[i], pairs[j] = pairs[j], pairs[i]
}
}
}
tokenStr, _ := tp.Decode([]int32{tokens[pos]})
sb.WriteString(fmt.Sprintf("Position %d (token_id=%d, token='%s'):\n", pos, tokens[pos], tokenStr))
sb.WriteString(" Top 10 logits:\n")
for i := 0; i < min(10, len(pairs)); i++ {
tokStr, _ := tp.Decode([]int32{int32(pairs[i].tokenID)})
// Pad token string to 20 chars for alignment
paddedTok := fmt.Sprintf("%-20s", fmt.Sprintf("'%s'", tokStr))
sb.WriteString(fmt.Sprintf(" %d. token_id=%6d (%s): %f\n", i+1, pairs[i].tokenID, paddedTok, pairs[i].value))
}
// First 20 logits
sb.WriteString(" Full logits (first 20): [")
for i := 0; i < min(20, len(posLogits)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n")
// Last 20 logits
sb.WriteString(" Full logits (last 20): [")
start := max(0, len(posLogits)-20)
for i := start; i < len(posLogits); i++ {
if i > start {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n\n")
}
if err := os.WriteFile("/tmp/olmo_logits.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.txt")
// Print to console as well
fmt.Println(sb.String())
}
func TestRoPE(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
// Test RoPE on a simple tensor
headDim := olmoModel.hiddenSize / olmoModel.numHeads
batchSize := 4
numHeads := olmoModel.numHeads
t.Logf("headDim: %d, numHeads: %d", headDim, numHeads)
t.Logf("ropeBase: %f, ropeScale: %f, originalContextLength: %d",
olmoModel.ropeBase, olmoModel.ropeScale, olmoModel.originalContextLength)
// Create test query tensor: (headDim, numHeads, batchSize)
queryFloats := make([]float32, headDim*numHeads*batchSize)
for i := range queryFloats {
queryFloats[i] = float32(i%100) / 100.0
}
// Test 1: Dump initial query values (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
t.Logf("query shape: %v", query.Shape())
query = query.Contiguous(ctx)
ctx.Forward(query).Compute(query)
dump := ml.Dump(ctx, query, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query BEFORE RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\n\n", query.Shape(), query.DType())
if err := os.WriteFile("/tmp/olmo_query_before_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_before_rope.bin", query.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_before_rope.txt and .bin")
}
// Test 2: SWA RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultSWA := olmoModel.applyRoPE(ctx, query, positions, headDim, true)
resultSWA = resultSWA.Contiguous(ctx)
ctx.Forward(resultSWA).Compute(resultSWA)
t.Logf("SWA RoPE result shape: %v", resultSWA.Shape())
dump := ml.Dump(ctx, resultSWA, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER SWA RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: 1.0 (SWA)\n\n", resultSWA.Shape(), resultSWA.DType())
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.bin", resultSWA.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_swa_rope.txt and .bin")
}
// Test 3: Global (non-SWA) RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultGlobal := olmoModel.applyRoPE(ctx, query, positions, headDim, false)
resultGlobal = resultGlobal.Contiguous(ctx)
ctx.Forward(resultGlobal).Compute(resultGlobal)
t.Logf("Global RoPE result shape: %v", resultGlobal.Shape())
dump := ml.Dump(ctx, resultGlobal, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER Global RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: %f (Global, 1/ropeScale)\n\n",
resultGlobal.Shape(), resultGlobal.DType(), 1.0/olmoModel.ropeScale)
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.bin", resultGlobal.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_global_rope.txt and .bin")
}
}

469
model/parsers/olmo3.go Normal file
View File

@@ -0,0 +1,469 @@
package parsers
import (
"context"
"fmt"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type olmo3ParserState int
const (
olmo3StateContent olmo3ParserState = iota
olmo3StateToolCalls
olmo3StateToolCallsDone
)
const (
olmo3FuncCallsOpenTag = "<function_calls>"
olmo3FuncCallsCloseTag = "</function_calls>"
)
type Olmo3Parser struct {
state olmo3ParserState
buffer strings.Builder
}
func (p *Olmo3Parser) HasToolSupport() bool {
return true
}
func (p *Olmo3Parser) HasThinkingSupport() bool {
return false
}
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.state = olmo3StateContent
return tools
}
type olmo3ParserEvent interface {
isOlmo3ParserEvent()
}
type olmo3ParserEventContent struct {
content string
}
type olmo3ParserEventToolCalls struct {
calls []api.ToolCall
}
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
if done {
// Drain any remaining content
bufStr := p.buffer.String()
p.buffer.Reset()
if p.state == olmo3StateContent && len(bufStr) > 0 {
return bufStr, "", nil, nil
}
return "", "", nil, nil
}
events := p.parseEvents()
var contentSb strings.Builder
var allCalls []api.ToolCall
for _, event := range events {
switch event := event.(type) {
case olmo3ParserEventContent:
contentSb.WriteString(event.content)
case olmo3ParserEventToolCalls:
allCalls = append(allCalls, event.calls...)
}
}
return contentSb.String(), "", allCalls, nil
}
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
var all []olmo3ParserEvent
keepLooping := true
for keepLooping {
var events []olmo3ParserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
var events []olmo3ParserEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case olmo3StateContent:
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
// Found <function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
content := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCalls
if len(content) > 0 {
events = append(events, olmo3ParserEventContent{content: content})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
// Partial <function_calls> tag - withhold ambiguous content
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, olmo3ParserEventContent{content: unambiguous})
}
return events, false
} else {
// Regular content - emit all
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
case olmo3StateToolCalls:
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
// Found </function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
toolCallsStr := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCallsDone
// Parse the function calls
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
if err != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
} else if len(calls) > 0 {
events = append(events, olmo3ParserEventToolCalls{calls: calls})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
// Partial </function_calls> tag - wait for more
return events, false
}
// Still collecting tool calls, wait for close tag
return events, false
case olmo3StateToolCallsDone:
// After tool calls, emit remaining content
p.buffer.Reset()
p.state = olmo3StateContent
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
return events, false
}
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
// func_name(arg1="value1", arg2=123)
// Multiple calls are separated by newlines
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
var calls []api.ToolCall
s = strings.TrimSpace(s)
if s == "" {
return calls, nil
}
// Split by newlines for multiple function calls
lines := strings.Split(s, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
call, err := parseOlmo3SingleFunctionCall(line)
if err != nil {
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
}
calls = append(calls, call)
}
return calls, nil
}
// Regex to match function call: func_name(args)
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
// Regex to match a single argument: key=value
// Value can be: "string", 'string', number, true, false, null, or nested structures
var argRegex = regexp.MustCompile(`^(\w+)=(.+)$`)
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
matches := funcCallRegex.FindStringSubmatch(s)
if matches == nil {
return api.ToolCall{}, fmt.Errorf("invalid function call format")
}
funcName := matches[1]
argsStr := matches[2]
args, err := parseOlmo3Arguments(argsStr)
if err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
}, nil
}
// parseOlmo3Arguments parses comma-separated key=value pairs
// Handles nested parentheses, brackets, braces, and quoted strings
func parseOlmo3Arguments(s string) (map[string]any, error) {
args := make(map[string]any)
s = strings.TrimSpace(s)
if s == "" {
return args, nil
}
// Split by commas, but respect nested structures and quotes
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find the first = sign
eqIdx := strings.Index(part, "=")
if eqIdx == -1 {
return nil, fmt.Errorf("invalid argument format: %s", part)
}
key := strings.TrimSpace(part[:eqIdx])
valueStr := strings.TrimSpace(part[eqIdx+1:])
value, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
}
args[key] = value
}
return args, nil
}
// splitArguments splits arguments by commas, respecting quotes and nested structures
func splitArguments(s string) []string {
var parts []string
var current strings.Builder
depth := 0
inString := false
stringChar := byte(0)
escaped := false
for i := 0; i < len(s); i++ {
c := s[i]
if escaped {
current.WriteByte(c)
escaped = false
continue
}
if c == '\\' && inString {
current.WriteByte(c)
escaped = true
continue
}
if (c == '"' || c == '\'') && !inString {
inString = true
stringChar = c
current.WriteByte(c)
continue
}
if c == stringChar && inString {
inString = false
stringChar = 0
current.WriteByte(c)
continue
}
if !inString {
switch c {
case '(', '[', '{':
depth++
current.WriteByte(c)
case ')', ']', '}':
depth--
current.WriteByte(c)
case ',':
if depth == 0 {
parts = append(parts, current.String())
current.Reset()
continue
}
current.WriteByte(c)
default:
current.WriteByte(c)
}
} else {
current.WriteByte(c)
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
func parseOlmo3Value(s string) (any, error) {
s = strings.TrimSpace(s)
// Check for quoted string
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
// Remove quotes and unescape
inner := s[1 : len(s)-1]
return unescapeString(inner), nil
}
// Check for boolean
if s == "true" || s == "True" {
return true, nil
}
if s == "false" || s == "False" {
return false, nil
}
// Check for null/None
if s == "null" || s == "None" || s == "nil" {
return nil, nil
}
// Check for number
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i, nil
}
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, nil
}
// Check for array [...]
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
return parseOlmo3Array(s[1 : len(s)-1])
}
// Check for object {...}
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
return parseOlmo3Object(s[1 : len(s)-1])
}
// Default to string without quotes
return s, nil
}
func parseOlmo3Array(s string) ([]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return []any{}, nil
}
parts := splitArguments(s)
var arr []any
for _, part := range parts {
val, err := parseOlmo3Value(part)
if err != nil {
return nil, err
}
arr = append(arr, val)
}
return arr, nil
}
func parseOlmo3Object(s string) (map[string]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return map[string]any{}, nil
}
// Objects use key: value or "key": value format
obj := make(map[string]any)
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find colon separator
colonIdx := strings.Index(part, ":")
if colonIdx == -1 {
return nil, fmt.Errorf("invalid object entry: %s", part)
}
keyStr := strings.TrimSpace(part[:colonIdx])
valueStr := strings.TrimSpace(part[colonIdx+1:])
// Remove quotes from key if present
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
keyStr = keyStr[1 : len(keyStr)-1]
}
val, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
}
obj[keyStr] = val
}
return obj, nil
}
func unescapeString(s string) string {
// Handle common escape sequences
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
s = strings.ReplaceAll(s, `\"`, `"`)
s = strings.ReplaceAll(s, `\'`, `'`)
s = strings.ReplaceAll(s, `\n`, "\n")
s = strings.ReplaceAll(s, `\t`, "\t")
s = strings.ReplaceAll(s, `\r`, "\r")
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
return s
}

483
model/parsers/olmo3_test.go Normal file
View File

@@ -0,0 +1,483 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
}{
{
name: "simple content",
input: "Hello, how can I help you?",
expectedContent: "Hello, how can I help you?",
},
{
name: "simple tool call",
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
},
},
{
name: "content then tool call",
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call with multiple arguments",
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
"date": "2024-01-15",
},
},
},
},
},
{
name: "multiple tool calls",
input: `<function_calls>get_weather(location="San Francisco")
get_weather(location="New York")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{
name: "tool call with numeric argument",
input: `<function_calls>set_temperature(value=72)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temperature",
Arguments: map[string]any{"value": int64(72)},
},
},
},
},
{
name: "tool call with float argument",
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_price",
Arguments: map[string]any{"amount": 19.99},
},
},
},
},
{
name: "tool call with boolean argument",
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "toggle_setting",
Arguments: map[string]any{"enabled": true},
},
},
},
},
{
name: "tool call with null argument",
input: `<function_calls>clear_value(field=null)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "clear_value",
Arguments: map[string]any{"field": nil},
},
},
},
},
{
name: "tool call with array argument",
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_items",
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
},
},
},
},
{
name: "tool call with dict argument",
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "update_config",
Arguments: map[string]any{
"settings": map[string]any{
"theme": "dark",
"fontSize": int64(14),
},
},
},
},
},
},
{
name: "tool call with nested dict",
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_request",
Arguments: map[string]any{
"data": map[string]any{
"user": map[string]any{
"name": "John",
"age": int64(30),
},
"active": true,
},
},
},
},
},
},
{
name: "tool call with no arguments",
input: `<function_calls>get_current_time()</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_time",
Arguments: map[string]any{},
},
},
},
},
{
name: "tool call with single quotes",
input: `<function_calls>search(query='hello world')</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": "hello world"},
},
},
},
},
{
name: "tool call with escaped quotes",
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": `say "hello"`},
},
},
},
},
{
name: "tool call with mixed argument types",
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_user",
Arguments: map[string]any{
"name": "John",
"age": int64(30),
"active": true,
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
content, thinking, calls, err := p.Add(tt.input, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Drain remaining content
finalContent, finalThinking, finalCalls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
content += finalContent
thinking += finalThinking
calls = append(calls, finalCalls...)
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedCalls []api.ToolCall
}{
{
name: "streaming content",
chunks: []string{"Hello, ", "how ", "can I help?"},
expectedContent: "Hello, how can I help?",
},
{
name: "streaming tool call",
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "streaming content then tool call",
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
expectedContent: "Let me check.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call tag split across chunks",
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: map[string]any{},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
var allContent string
var allCalls []api.ToolCall
for _, chunk := range tt.chunks {
content, _, calls, err := p.Add(chunk, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
}
// Drain
content, _, calls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
p := &Olmo3Parser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
p := &Olmo3Parser{}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
}
func TestParseOlmo3FunctionCalls(t *testing.T) {
tests := []struct {
name string
input string
expected []api.ToolCall
wantErr bool
}{
{
name: "simple call",
input: `get_weather(location="SF")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "multiple args",
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "send_email",
Arguments: map[string]any{
"to": "user@example.com",
"subject": "Hello",
"body": "Test message",
},
},
},
},
},
{
name: "multiple calls with newlines",
input: `get_weather(location="SF")
get_time(timezone="PST")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: map[string]any{"timezone": "PST"},
},
},
},
},
{
name: "empty input",
input: "",
expected: nil,
},
{
name: "whitespace only",
input: " \n ",
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
calls, err := parseOlmo3FunctionCalls(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(calls, tt.expected); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestParseOlmo3Value(t *testing.T) {
tests := []struct {
name string
input string
expected any
}{
{"string double quotes", `"hello"`, "hello"},
{"string single quotes", `'hello'`, "hello"},
{"integer", "42", int64(42)},
{"negative integer", "-10", int64(-10)},
{"float", "3.14", 3.14},
{"boolean true", "true", true},
{"boolean True", "True", true},
{"boolean false", "false", false},
{"null", "null", nil},
{"None", "None", nil},
{"empty array", "[]", []any{}},
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
{"empty object", "{}", map[string]any{}},
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseOlmo3Value(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(result, tt.expected); diff != "" {
t.Errorf("value mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -60,6 +60,8 @@ func ParserForName(name string) Parser {
return &CogitoParser{}
case "olmo3-think":
return &Olmo3ThinkParser{}
case "olmo3":
return &Olmo3Parser{}
default:
return nil
}

148
model/renderers/olmo3.go Normal file
View File

@@ -0,0 +1,148 @@
package renderers
import (
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/api"
)
const (
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)
type Olmo3Renderer struct{}
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
var systemMessage *api.Message
filteredMessages := make([]api.Message, 0, len(messages))
for i, message := range messages {
if message.Role == "system" {
if systemMessage == nil {
systemMessage = &messages[i]
}
continue
}
filteredMessages = append(filteredMessages, message)
}
// Render system message
if systemMessage != nil {
// Custom system message - single newline after "system"
sb.WriteString("<|im_start|>system\n")
sb.WriteString(systemMessage.Content)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString("<functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
}
sb.WriteString("<|im_end|>\n")
} else {
// Default system message - single newline after "system"
sb.WriteString("<|im_start|>system\n")
sb.WriteString(olmo3DefaultSystemMessage)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString(olmo3WithFunctionsMessage)
sb.WriteString("<functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
} else {
sb.WriteString(olmo3NoFunctionsMessage)
sb.WriteString("<functions></functions>")
}
sb.WriteString("<|im_end|>\n")
}
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
switch message.Role {
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
sb.WriteString("<function_calls>")
for j, tc := range message.ToolCalls {
// Format as function_name(arg1="value1", arg2="value2")
sb.WriteString(tc.Function.Name)
sb.WriteString("(")
// Get sorted keys for deterministic output
keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments {
keys = append(keys, k)
}
sort.Strings(keys)
for k, key := range keys {
if k > 0 {
sb.WriteString(", ")
}
value, err := json.Marshal(tc.Function.Arguments[key])
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
}
sb.WriteString(")")
if j < len(message.ToolCalls)-1 {
sb.WriteString("\n")
}
}
sb.WriteString("</function_calls>")
}
// Add end tag unless it's the last message with content only (prefill)
if !lastMessage || len(message.ToolCalls) > 0 {
sb.WriteString("<|im_end|>\n")
}
case "tool":
sb.WriteString("<|im_start|>environment\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
}
}
// Add generation prompt if needed
needsGenerationPrompt := true
if len(filteredMessages) > 0 {
lastMsg := filteredMessages[len(filteredMessages)-1]
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
needsGenerationPrompt = false
}
}
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n\n")
}
return sb.String(), nil
}

View File

@@ -0,0 +1,290 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3Renderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic without system - adds default system",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "with system message no tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "with system message and tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "default system with tools - includes function instruction",
msgs: []api.Message{
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "assistant with tool calls - function call syntax",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather in SF?"},
{
Role: "assistant",
Content: "Let me check the weather.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"location": "San Francisco",
},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather in SF?<|im_end|>\n" +
"<|im_start|>assistant\n" +
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "multi-turn conversation",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Hi there!<|im_end|>\n" +
"<|im_start|>user\n" +
"How are you?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "parallel tool calls - newline separated",
msgs: []api.Message{
{Role: "user", Content: "Get weather in SF and NYC"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Get weather in SF and NYC<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>get_weather(location="San Francisco")` + "\n" +
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 55}<|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "tool call with multiple arguments",
msgs: []api.Message{
{Role: "user", Content: "Book a flight"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
},
},
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "book_flight",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"from": {Type: api.PropertyType{"string"}},
"to": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Book a flight<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "assistant prefill - no generation prompt",
msgs: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Hi there!",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -62,6 +62,9 @@ func rendererForName(name string) Renderer {
case "olmo3-think":
renderer := &Olmo3ThinkRenderer{}
return renderer
case "olmo3":
renderer := &Olmo3Renderer{}
return renderer
default:
return nil
}