mirror of
https://github.com/ollama/ollama.git
synced 2026-02-07 14:13:29 -05:00
Compare commits
24 Commits
brucemacd/
...
parth/opt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beaa0e82f3 | ||
|
|
8bcb3125c1 | ||
|
|
6baf1e31e2 | ||
|
|
ed567ef43b | ||
|
|
a6e64fbdf2 | ||
|
|
60cfa2a203 | ||
|
|
55bbf3b4a1 | ||
|
|
6bda1d2479 | ||
|
|
9e125d884c | ||
|
|
a6fbfc880c | ||
|
|
502028968d | ||
|
|
5a8eb0e151 | ||
|
|
9f8a18ec05 | ||
|
|
6b04cad7e8 | ||
|
|
45f56355d5 | ||
|
|
0dabb4ef6a | ||
|
|
2e77aa1ae7 | ||
|
|
deaabe292d | ||
|
|
af21a5ac39 | ||
|
|
f63d7f68eb | ||
|
|
82ad1dbc07 | ||
|
|
feeabdadd2 | ||
|
|
fc0309615e | ||
|
|
09d308d6b6 |
@@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
|
||||
To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
|
||||
|
||||
```shell
|
||||
ollama run llama3.2
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
## Model library
|
||||
@@ -407,6 +407,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -451,6 +453,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
|
||||
15
api/types.go
15
api/types.go
@@ -285,6 +285,7 @@ type Options struct {
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ShiftContext bool `json:"shift_context,omitempty"`
|
||||
}
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
@@ -457,13 +458,12 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
@@ -664,6 +664,7 @@ func DefaultOptions() Options {
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
ShiftContext: true,
|
||||
|
||||
Runner: Runner{
|
||||
// options set when the model is loaded
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Command line flags
|
||||
var modelFlag string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||
flag.Lookup("m").DefValue = "model"
|
||||
}
|
||||
|
||||
// modelName returns the model name from flags, failing the test if not set
|
||||
func modelName(b *testing.B) string {
|
||||
if modelFlag == "" {
|
||||
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||
}
|
||||
return modelFlag
|
||||
}
|
||||
|
||||
type TestCase struct {
|
||||
name string
|
||||
prompt string
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// runGenerateBenchmark contains the common generate and metrics logic
|
||||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||
start := time.Now()
|
||||
var ttft time.Duration
|
||||
var metrics api.Metrics
|
||||
|
||||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if ttft == 0 && resp.Response != "" {
|
||||
ttft = time.Since(start)
|
||||
}
|
||||
if resp.Done {
|
||||
metrics = resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Report custom metrics as part of the benchmark results
|
||||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||
|
||||
// Token throughput metrics
|
||||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||
|
||||
// Token counts
|
||||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||
func BenchmarkColdStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
b.StopTimer()
|
||||
// Ensure model is unloaded before each iteration
|
||||
unload(client, m, b)
|
||||
b.StartTimer()
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||
func BenchmarkWarmStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setup verifies server and model availability
|
||||
func setup(b *testing.B) *api.Client {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// warmup ensures the model is loaded and warmed up
|
||||
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
for range 3 {
|
||||
err := client.Generate(
|
||||
context.Background(),
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
b.Logf("Error during model warm-up: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unload forces model unloading using KeepAlive: 0 parameter
|
||||
func unload(client *api.Client, model string, b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||
b.Logf("Unload error: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.Contains(link, "Ollama.app") {
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
path := strings.Split(link, "Ollama.app")
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
|
||||
@@ -47,7 +47,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "hidden")
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
|
||||
@@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
for _, t := range ts {
|
||||
if strings.Contains(t.Name(), "patch_embed.proj") {
|
||||
for t := range splitDim(t, 2,
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
|
||||
) {
|
||||
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
|
||||
out = append(out, t)
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "attn.qkv") {
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
strings.NewReplacer("attn.qkv", "attn_q"),
|
||||
strings.NewReplacer("attn.qkv", "attn_k"),
|
||||
strings.NewReplacer("attn.qkv", "attn_v"),
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
|
||||
@@ -1,53 +1,73 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type split struct {
|
||||
*strings.Replacer
|
||||
dim int
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
}
|
||||
|
||||
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
||||
// is split evenly based on the number of replacers provided.
|
||||
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
|
||||
// is split evenly based on the number of replacers provided unless a specific count is given.
|
||||
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
return func(yield func(*ggml.Tensor) bool) {
|
||||
for i, replacer := range replacers {
|
||||
var offset int
|
||||
for _, split := range splits {
|
||||
t := t.Clone()
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = shape[dim] / uint64(len(replacers))
|
||||
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(slice...)
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(slice...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if split.fn != nil {
|
||||
tt, err = split.fn(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// flatten tensor so it can be written as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
|
||||
if !yield(&ggml.Tensor{
|
||||
Name: replacer.Replace(t.Name()),
|
||||
Name: split.Replace(t.Name()),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: tt,
|
||||
WriterTo: t,
|
||||
}) {
|
||||
break
|
||||
}
|
||||
|
||||
304
convert/tensor_test.go
Normal file
304
convert/tensor_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
)
|
||||
|
||||
type fakeTensor struct {
|
||||
name string
|
||||
shape []uint64
|
||||
data []float32
|
||||
|
||||
repacker Repacker
|
||||
}
|
||||
|
||||
func (f fakeTensor) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f fakeTensor) Shape() []uint64 {
|
||||
return f.shape
|
||||
}
|
||||
|
||||
func (f fakeTensor) Kind() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
||||
f.repacker = fn
|
||||
}
|
||||
|
||||
func (f fakeTensor) Clone() Tensor {
|
||||
return &fakeTensor{
|
||||
name: f.name,
|
||||
shape: slices.Clone(f.shape),
|
||||
data: slices.Clone(f.data),
|
||||
repacker: f.repacker,
|
||||
}
|
||||
}
|
||||
|
||||
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
|
||||
data := f.data
|
||||
if f.repacker != nil {
|
||||
data, err = f.repacker(f.name, data, f.shape)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(data) * 4), nil
|
||||
}
|
||||
|
||||
func mul(shape []uint64) int {
|
||||
n := 1
|
||||
for _, dim := range shape {
|
||||
n *= int(dim)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func TestSplitDim(t *testing.T) {
|
||||
r := fakeTensor{
|
||||
name: "a.b",
|
||||
shape: []uint64{3, 4},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
}
|
||||
|
||||
t.Run("no split", func(t *testing.T) {
|
||||
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
|
||||
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
|
||||
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("even split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y")},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
||||
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
|
||||
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
|
||||
t.Fatal("expected shape [2, 4], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
|
||||
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
|
||||
t.Fatal("expected shape [1, 4], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
|
||||
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split with transpose", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||
return tensor.Transpose(tt, 1, 0)
|
||||
}},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
||||
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
||||
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
|
||||
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
29
docs/api.md
29
docs/api.md
@@ -1157,15 +1157,11 @@ A single JSON object will be returned.
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
|
||||
"model": "codellama:13b",
|
||||
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
|
||||
"size": 7365960935,
|
||||
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
|
||||
"capabilities": [
|
||||
"completion"
|
||||
],
|
||||
|
||||
"name": "deepseek-r1:latest",
|
||||
"model": "deepseek-r1:latest",
|
||||
"modified_at": "2025-05-10T08:06:48.639712648-07:00",
|
||||
"size": 4683075271,
|
||||
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
@@ -1178,16 +1174,11 @@ A single JSON object will be returned.
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
"model": "llama4:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
|
||||
"name": "llama3.2:latest",
|
||||
"model": "llama3.2:latest",
|
||||
"modified_at": "2025-05-04T17:37:44.706015396-07:00",
|
||||
"size": 2019393189,
|
||||
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# Benchmark
|
||||
|
||||
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||
|
||||
## When to use
|
||||
|
||||
Run these benchmarks when:
|
||||
- Making changes to the model inference engine
|
||||
- Modifying model loading/unloading logic
|
||||
- Changing prompt processing or token generation code
|
||||
- Implementing a new model architecture
|
||||
- Testing performance across different hardware setups
|
||||
|
||||
## Prerequisites
|
||||
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||
## Usage and Examples
|
||||
|
||||
>[!NOTE]
|
||||
>All commands must be run from the root directory of the Ollama project.
|
||||
|
||||
Basic syntax:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||
```
|
||||
|
||||
Required flags:
|
||||
- `-bench=.`: Run all benchmarks
|
||||
- `-m`: Model name to benchmark
|
||||
|
||||
Optional flags:
|
||||
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||
|
||||
Common usage patterns:
|
||||
|
||||
Single benchmark run with a model specified:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m llama3.3
|
||||
```
|
||||
|
||||
## Output metrics
|
||||
|
||||
The benchmark reports several key metrics:
|
||||
|
||||
- `gen_tok/s`: Generated tokens per second
|
||||
- `prompt_tok/s`: Prompt processing tokens per second
|
||||
- `ttft_ms`: Time to first token in milliseconds
|
||||
- `load_ms`: Model load time in milliseconds
|
||||
- `gen_tokens`: Total tokens generated
|
||||
- `prompt_tokens`: Total prompt tokens processed
|
||||
|
||||
Each benchmark runs two scenarios:
|
||||
- Cold start: Model is loaded from disk for each test
|
||||
- Warm start: Model is pre-loaded in memory
|
||||
|
||||
Three prompt lengths are tested for each scenario:
|
||||
- Short prompt (100 tokens)
|
||||
- Medium prompt (500 tokens)
|
||||
- Long prompt (1000 tokens)
|
||||
@@ -112,8 +112,8 @@ sudo systemctl status ollama
|
||||
> While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
> kernel source, the version is older and may not support all ROCm features. We
|
||||
> recommend you install the latest driver from
|
||||
> https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
> GPU.
|
||||
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support
|
||||
> of your Radeon GPU.
|
||||
|
||||
## Customizing
|
||||
|
||||
|
||||
@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := slices.Collect(maps.Keys(kv))
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
return -1
|
||||
} else {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var s uint64
|
||||
|
||||
@@ -2,62 +2,82 @@ package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 608,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,6 +700,8 @@ const (
|
||||
DoneReasonStop DoneReason = iota
|
||||
// DoneReasonLength indicates the completion stopped due to length limits
|
||||
DoneReasonLength
|
||||
// DoneReasonContextShift indicates the completion stopped due to context shift
|
||||
DoneReasonContextShift
|
||||
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
||||
DoneReasonConnectionClosed
|
||||
)
|
||||
@@ -710,6 +712,8 @@ func (d DoneReason) String() string {
|
||||
return "length"
|
||||
case DoneReasonStop:
|
||||
return "stop"
|
||||
case DoneReasonContextShift:
|
||||
return "context_limit_reached"
|
||||
default:
|
||||
return "" // closed
|
||||
}
|
||||
|
||||
@@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp
|
||||
}
|
||||
|
||||
type TextExperts struct {
|
||||
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
|
||||
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
|
||||
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
|
||||
hiddenStates = hiddenStates.Mul(ctx, scores)
|
||||
|
||||
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
|
||||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
|
||||
@@ -66,9 +66,9 @@ type MLP interface {
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
|
||||
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
|
||||
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
@@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
|
||||
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = hiddenStates.SILU(ctx)
|
||||
hiddenStates = hiddenStates.Mul(ctx, upStates)
|
||||
|
||||
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
|
||||
@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
|
||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||
v.specialOnce.Do(func() {
|
||||
for i := range v.Values {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
||||
v.special = append(v.special, v.Values[i])
|
||||
}
|
||||
}
|
||||
|
||||
16
model/vocabulary_test.go
Normal file
16
model/vocabulary_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||
}
|
||||
|
||||
specialVocab := vocab.SpecialVocabulary()
|
||||
|
||||
if len(specialVocab) != 4 {
|
||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||
}
|
||||
}
|
||||
@@ -292,13 +292,18 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
files = append(files, js...)
|
||||
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
// only include tokenizer.model is tokenizer.json is not present
|
||||
if !slices.ContainsFunc(files, func(s string) bool {
|
||||
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
||||
}) {
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
|
||||
@@ -80,6 +80,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -90,11 +93,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
if len(inputs) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := len(inputs) - s.cache.numCtx
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
default:
|
||||
err := s.processBatch(tokenBatch, embedBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
slog.Error("error processing batch", "error", err)
|
||||
}
|
||||
|
||||
tokenBatch.Clear()
|
||||
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
@@ -573,11 +595,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
152
runner/llamarunner/runner_test.go
Normal file
152
runner/llamarunner/runner_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -95,11 +98,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
|
||||
break
|
||||
}
|
||||
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
@@ -629,11 +651,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
167
runner/ollamarunner/runner_test.go
Normal file
167
runner/ollamarunner/runner_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestEnableContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove with DoneReasonContextShift",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - exact limit",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 100,
|
||||
pendingInputs: 0,
|
||||
minBatch: 1,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch - matches actual implementation
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Within limits - should not remove
|
||||
if tt.shouldRemove {
|
||||
t.Errorf("should not remove sequence when within context limits")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
2
server/internal/cache/blob/cache.go
vendored
2
server/internal/cache/blob/cache.go
vendored
@@ -59,7 +59,7 @@ type DiskCache struct {
|
||||
testHookBeforeFinalWrite func(f *os.File)
|
||||
}
|
||||
|
||||
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
|
||||
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
|
||||
}
|
||||
|
||||
@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
|
||||
if ctxLen > opts.NumCtx {
|
||||
if !opts.ShiftContext {
|
||||
return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
|
||||
}
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
break
|
||||
} else {
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -56,7 +57,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||
error: fmt.Errorf("context length of 1 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -69,10 +70,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -208,12 +200,25 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
|
||||
// For truncation tests, disable context shifting to test the truncation behavior
|
||||
if tt.name == "truncate messages" ||
|
||||
tt.name == "truncate messages with image" ||
|
||||
tt.name == "truncate messages with images" ||
|
||||
tt.name == "truncate message with interleaved images" {
|
||||
opts.ShiftContext = false
|
||||
}
|
||||
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
} else if tt.error != nil && err != nil {
|
||||
if err.Error() != tt.error.Error() {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
}
|
||||
} else if tt.error != nil && err == nil {
|
||||
t.Fatalf("expected err '%q', got nil", tt.error)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||
|
||||
@@ -929,7 +929,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
r := api.ListModelResponse{
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
@@ -942,16 +943,7 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
ParameterSize: cf.ModelType,
|
||||
QuantizationLevel: cf.FileType,
|
||||
},
|
||||
}
|
||||
|
||||
model, err := GetModel(n.String())
|
||||
if err != nil {
|
||||
slog.Warn("bad model details", "name", n, "error", err)
|
||||
} else {
|
||||
r.Capabilities = model.Capabilities()
|
||||
}
|
||||
|
||||
models = append(models, r)
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
|
||||
@@ -1534,12 +1526,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
@@ -1592,6 +1579,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -968,3 +970,154 @@ func TestWaitForStream(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftNonStreamingResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
responses []llm.CompletionResponse
|
||||
expectedDone bool
|
||||
expectedDoneReason string
|
||||
}{
|
||||
{
|
||||
name: "context shifting disabled - should have DoneReasonLength",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonLength},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "length",
|
||||
},
|
||||
{
|
||||
name: "context shifting enabled - should have DoneReasonStop",
|
||||
enableContextShift: true,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "stop",
|
||||
},
|
||||
{
|
||||
name: "no final response with Done=true",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
},
|
||||
expectedDone: false,
|
||||
expectedDoneReason: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// The last response in the channel will naturally be the final state
|
||||
lastResponse := tt.responses[len(tt.responses)-1]
|
||||
|
||||
if lastResponse.Done != tt.expectedDone {
|
||||
t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done)
|
||||
}
|
||||
|
||||
if tt.expectedDoneReason != "" {
|
||||
if lastResponse.DoneReason.String() != tt.expectedDoneReason {
|
||||
t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleScheduleError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMessage string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "context length exceeded error",
|
||||
errorMessage: "context length of 100 tokens exceeded, context shifting is disabled",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
errorMessage: "some other error",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
err := errors.New(tt.errorMessage)
|
||||
|
||||
handleScheduleError(c, "test-model", err)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage {
|
||||
t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftOptions(t *testing.T) {
|
||||
t.Run("default options have enableContextShift=true", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
if !opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=true by default, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("can set enableContextShift to false", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
if opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=false after setting, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization omits false values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is not in the JSON when false
|
||||
if bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be omitted from JSON when false, but found it in: %s", string(data))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization includes true values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = true
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is in the JSON when true
|
||||
if !bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be in JSON when true, but not found in: %s", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
156
tools/template.go
Normal file
156
tools/template.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
)
|
||||
|
||||
// parseTag finds the tool calling tag from a Go template
|
||||
// often <tool_call> [TOOL_CALL] or similar by finding the
|
||||
// first text node after .ToolCalls and returning the content
|
||||
// if no tag is found, return "{" to indicate that json objects
|
||||
// should be attempted to be parsed as tool calls
|
||||
func parseTag(tmpl *template.Template) string {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("template or tree is nil")
|
||||
return "{"
|
||||
}
|
||||
|
||||
tc := findToolCallNode(tmpl.Tree.Root.Nodes)
|
||||
if tc == nil {
|
||||
return "{"
|
||||
}
|
||||
|
||||
tn := findTextNode(tc.List.Nodes)
|
||||
if tn == nil {
|
||||
return "{"
|
||||
}
|
||||
|
||||
tag := string(tn.Text)
|
||||
tag = strings.ReplaceAll(tag, "\r\n", "\n")
|
||||
|
||||
// avoid parsing { onwards as this may be a tool call
|
||||
// however keep '{' as a prefix if there is no tag
|
||||
// so that all json objects will be attempted to
|
||||
// be parsed as tool calls
|
||||
tag, _, _ = strings.Cut(tag, "{")
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag == "" {
|
||||
tag = "{"
|
||||
}
|
||||
|
||||
return tag
|
||||
}
|
||||
|
||||
// findToolCallNode searches for and returns an IfNode with .ToolCalls
|
||||
func findToolCallNode(nodes []parse.Node) *parse.IfNode {
|
||||
isToolCallsNode := func(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
if slices.Contains(field.Ident, "ToolCalls") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if isToolCallsNode(n) {
|
||||
return n
|
||||
}
|
||||
// Recursively search in nested IfNodes
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case *parse.ListNode:
|
||||
if result := findToolCallNode(n.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
case *parse.RangeNode:
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case *parse.WithNode:
|
||||
if result := findToolCallNode(n.List.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findTextNode does a depth-first search for the first text content in nodes,
|
||||
// stopping at template constructs to avoid parsing text after the tool calls
|
||||
func findTextNode(nodes []parse.Node) *parse.TextNode {
|
||||
for _, node := range nodes {
|
||||
switch n := node.(type) {
|
||||
case *parse.TextNode:
|
||||
// skip whitespace-only text nodes
|
||||
if len(bytes.TrimSpace(n.Text)) == 0 {
|
||||
continue
|
||||
}
|
||||
return n
|
||||
case *parse.IfNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.ListNode:
|
||||
if text := findTextNode(n.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
case *parse.RangeNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.WithNode:
|
||||
if text := findTextNode(n.List.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
if n.ElseList != nil {
|
||||
if text := findTextNode(n.ElseList.Nodes); text != nil {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case *parse.ActionNode:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
139
tools/template_test.go
Normal file
139
tools/template_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestParseTag(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
template: "",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "no tag",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "no tag with range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json",
|
||||
},
|
||||
{
|
||||
name: "square brackets",
|
||||
template: "{{if .ToolCalls}}[{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "[",
|
||||
},
|
||||
{
|
||||
name: "square brackets with whitespace",
|
||||
template: "{{if .ToolCalls}}\n [ {{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "[",
|
||||
},
|
||||
{
|
||||
name: "tailing ]",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
template: "{{if .ToolCalls}} {{range .ToolCalls}}{{ . }}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "whitespace only in range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{{ . }}\n{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`,
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with whitespace",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with CRLF",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}\r\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "json objects with whitespace before and after range",
|
||||
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\r\n{{end}}\r\n{{end}}",
|
||||
want: "{",
|
||||
},
|
||||
{
|
||||
name: "before and after range",
|
||||
template: "{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>functionget_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>\n{{end}}<|tool▁calls▁end|>{{end}}",
|
||||
want: "<|tool▁calls▁begin|>",
|
||||
},
|
||||
{
|
||||
name: "after range",
|
||||
template: "{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "after range with leading whitespace before range",
|
||||
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "tool call in range with {",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}<tool_call>{{end}}{{end}}`,
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
},
|
||||
{
|
||||
name: "action tag",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action: ```json",
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools[",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL] [",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL][",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got := parseTag(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
67
tools/testdata/command-r-plus.gotmpl
vendored
67
tools/testdata/command-r-plus.gotmpl
vendored
@@ -1,67 +0,0 @@
|
||||
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||
{{- if .Tools }}# Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
{{ if .System }}# User Preamble
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
{{- range .Tools }}
|
||||
|
||||
```python
|
||||
def {{ .Function.Name }}(
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
|
||||
"""{{ .Function.Description }}
|
||||
|
||||
{{- if .Function.Parameters.Properties }}
|
||||
|
||||
Args:
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}
|
||||
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
"""
|
||||
pass
|
||||
```
|
||||
{{- end }}
|
||||
{{- else if .System }}{{ .System }}
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "system" }}
|
||||
{{- continue }}
|
||||
{{- end }}<|START_OF_TURN_TOKEN|>
|
||||
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}
|
||||
Action: ```json
|
||||
[
|
||||
{{- range .ToolCalls }}
|
||||
{
|
||||
"tool_name": "{{ .Function.Name }}",
|
||||
"parameters": {{ .Function.Arguments }}
|
||||
}
|
||||
{{- end }}
|
||||
]```
|
||||
{{ continue }}
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
|
||||
{{ .Content }}</results>
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```
|
||||
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
39
tools/testdata/command-r-plus.out
vendored
39
tools/testdata/command-r-plus.out
vendored
@@ -1,39 +0,0 @@
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
# User Preamble
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
|
||||
```python
|
||||
def get_current_weather(format: string, location: string, ) -> List[Dict]:
|
||||
"""Get the current weather
|
||||
|
||||
Args:
|
||||
format (string): The temperature unit to use. Infer this from the user's location.
|
||||
location (string): The city and state, e.g. San Francisco, CA
|
||||
"""
|
||||
pass
|
||||
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
Action: ```json
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {"format":"celsius","location":"Paris, France"}
|
||||
}
|
||||
]```
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
|
||||
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
31
tools/testdata/firefunction.gotmpl
vendored
31
tools/testdata/firefunction.gotmpl
vendored
@@ -1,31 +0,0 @@
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>
|
||||
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
|
||||
{{- end }}<|end_header_id|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }} functools[
|
||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||
{{- end }}]
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
17
tools/testdata/firefunction.out
vendored
17
tools/testdata/firefunction.out
vendored
@@ -1,17 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
43
tools/testdata/llama3-groq-tool-use.gotmpl
vendored
43
tools/testdata/llama3-groq-tool-use.gotmpl
vendored
@@ -1,43 +0,0 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}
|
||||
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{{- range .Tools }} {{ .Function }}
|
||||
{{- end }} </tools>
|
||||
{{- end }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range .Messages }}
|
||||
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||
|
||||
{{ if eq .Role "user" }}{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response>
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}{{ .Response }}
|
||||
{{- if .Response }}<|eot_id|>
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3-groq-tool-use.out
vendored
24
tools/testdata/llama3-groq-tool-use.out
vendored
@@ -1,24 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
||||
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
44
tools/testdata/llama3.2.gotmpl
vendored
44
tools/testdata/llama3.2.gotmpl
vendored
@@ -1,44 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
{{ if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
||||
{{- if and $.Tools $last }}
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{{ range $.Tools }}
|
||||
{{- . }}
|
||||
{{ end }}
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}
|
||||
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3.2.out
vendored
24
tools/testdata/llama3.2.out
vendored
@@ -1,24 +0,0 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
39
tools/testdata/messages.json
vendored
39
tools/testdata/messages.json
vendored
@@ -1,39 +0,0 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a knowledgeable assistant. You can answer questions and perform tasks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in Paris?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "Paris, France",
|
||||
"format": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"content": "22"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current temperature in Paris, France is 22 degrees Celsius."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in San Francisco and Toronto?"
|
||||
}
|
||||
]
|
||||
15
tools/testdata/mistral.gotmpl
vendored
15
tools/testdata/mistral.gotmpl
vendored
@@ -1,15 +0,0 @@
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}
|
||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||
|
||||
{{ end }}{{ .Content }}[/INST]
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }} {{ .Content }}</s>
|
||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}]</s>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
3
tools/testdata/mistral.out
vendored
3
tools/testdata/mistral.out
vendored
@@ -1,3 +0,0 @@
|
||||
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?[/INST]
|
||||
33
tools/testdata/nemotron.gotmpl
vendored
33
tools/testdata/nemotron.gotmpl
vendored
@@ -1,33 +0,0 @@
|
||||
{{- if (or .Tools .System) }}<extra_id_0>System
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
|
||||
{{ end }}
|
||||
{{- if .Tools }}
|
||||
{{- range .Tools }}<tool> {{ . }} </tool>{{ end }}
|
||||
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- range $i, $m := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<extra_id_1>User
|
||||
{{ .Content }}
|
||||
{{- if $last }}
|
||||
<extra_id_1>Assistant
|
||||
{{- end }}
|
||||
{{ else if eq .Role "tool" }}<extra_id_1>Tool
|
||||
{{ .Content }}
|
||||
{{- if $last }}
|
||||
<extra_id_1>Assistant
|
||||
{{- end }}
|
||||
{{ else if eq .Role "assistant" }}<extra_id_1>Assistant
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}<toolcall> {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} </toolcall> {{ end }}
|
||||
{{ else }}
|
||||
{{ .Content }}
|
||||
{{- if not $last }}
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
18
tools/testdata/nemotron.out
vendored
18
tools/testdata/nemotron.out
vendored
@@ -1,18 +0,0 @@
|
||||
<extra_id_0>System
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
|
||||
<tool> {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} </tool>
|
||||
|
||||
|
||||
<extra_id_1>User
|
||||
What's the weather like today in Paris?
|
||||
<extra_id_1>Assistant
|
||||
<toolcall> {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} </toolcall>
|
||||
<extra_id_1>Tool
|
||||
22
|
||||
<extra_id_1>Assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.
|
||||
<extra_id_1>User
|
||||
What's the weather like today in San Francisco and Toronto?
|
||||
<extra_id_1>Assistant
|
||||
51
tools/testdata/qwen2.5.gotmpl
vendored
51
tools/testdata/qwen2.5.gotmpl
vendored
@@ -1,51 +0,0 @@
|
||||
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
|
||||
{{- else if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen2.5.out
vendored
31
tools/testdata/qwen2.5.out
vendored
@@ -1,31 +0,0 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
50
tools/testdata/qwen3.gotmpl
vendored
50
tools/testdata/qwen3.gotmpl
vendored
@@ -1,50 +0,0 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen3.out
vendored
31
tools/testdata/qwen3.out
vendored
@@ -1,31 +0,0 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
30
tools/testdata/tools.json
vendored
30
tools/testdata/tools.json
vendored
@@ -1,30 +0,0 @@
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
],
|
||||
"description": "The temperature unit to use. Infer this from the user's location."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location",
|
||||
"format"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
45
tools/testdata/xlam.gotmpl
vendored
45
tools/testdata/xlam.gotmpl
vendored
@@ -1,45 +0,0 @@
|
||||
{{- if .System }}{{ .System }}
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction:
|
||||
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
[END OF TASK INSTRUCTION]
|
||||
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
{{ $.Tools }}
|
||||
[END OF AVAILABLE TOOLS]
|
||||
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
|
||||
[BEGIN OF QUERY]
|
||||
{{ .Content }}
|
||||
[END OF QUERY]
|
||||
|
||||
|
||||
{{ else }}
|
||||
{{ .Content }}
|
||||
{{ end }}
|
||||
{{- else if .ToolCalls }}### Response:
|
||||
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
|
||||
<|EOT|>
|
||||
{{ else if eq .Role "assistant" }}### Response:
|
||||
{{ .Content }}
|
||||
<|EOT|>
|
||||
{{ end }}
|
||||
{{- end }}### Response:
|
||||
40
tools/testdata/xlam.out
vendored
40
tools/testdata/xlam.out
vendored
@@ -1,40 +0,0 @@
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
### Instruction:
|
||||
What's the weather like today in Paris?
|
||||
### Response:
|
||||
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
|
||||
<|EOT|>
|
||||
### Response:
|
||||
The current temperature in Paris, France is 22 degrees Celsius.
|
||||
<|EOT|>
|
||||
### Instruction:
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
[END OF TASK INSTRUCTION]
|
||||
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
|
||||
[END OF AVAILABLE TOOLS]
|
||||
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
|
||||
[BEGIN OF QUERY]
|
||||
What's the weather like today in San Francisco and Toronto?
|
||||
[END OF QUERY]
|
||||
|
||||
|
||||
### Response:
|
||||
511
tools/tools.go
511
tools/tools.go
@@ -1,253 +1,294 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToolCall = errors.New("invalid tool call format")
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
type toolsState int
|
||||
|
||||
const (
|
||||
toolsState_LookingForTag toolsState = iota
|
||||
toolsState_ToolCalling
|
||||
toolsState_Done
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
tag string
|
||||
tools []api.Tool
|
||||
|
||||
state toolsState
|
||||
buffer []byte
|
||||
n int
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to parse
|
||||
// - name: The field name from template that identifies the tool call name
|
||||
// - arguments: The field name from template that identifies the tool call arguments
|
||||
//
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||
// Check for balanced braces before attempting to parse
|
||||
braceCount := 0
|
||||
squareCount := 0
|
||||
startIndex := -1
|
||||
var rawToolCalls []string
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
case '[':
|
||||
if trackSquareBrackets {
|
||||
squareCount++
|
||||
}
|
||||
case ']':
|
||||
if trackSquareBrackets {
|
||||
squareCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Negative means we have an extra closing brace/bracket
|
||||
if braceCount < 0 || squareCount < 0 {
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// If braces/brackets aren't balanced, need more input
|
||||
if braceCount > 0 || squareCount > 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
t := strings.TrimSpace(s)
|
||||
if len(t) == 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
// If the input is a single square bracket, it's not a valid tool call
|
||||
if t[0] == '[' && len(t) == 1 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var toolCalls []api.ToolCall
|
||||
for _, rawToolCall := range rawToolCalls {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect nested objects that could contain tool calls
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
// NewParser creates a new tool call parser from a model's chat
|
||||
// template and a list of provided tools.
|
||||
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
|
||||
return NewParserWithTag(tools, parseTag(tmpl))
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
if s == "" || p.prefix == "" {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
p.prefixFound = true
|
||||
return cut, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||
// Return everything before prefix
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParseJSON && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, s
|
||||
}
|
||||
|
||||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
}
|
||||
p.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.prefix != "" {
|
||||
p.greedyParseJSON = false
|
||||
}
|
||||
if p.index != 0 && p.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return toolCalls, ""
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParseJSON: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
tag: tag,
|
||||
tools: tools,
|
||||
}
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content that
|
||||
// should be sent back to the user.
|
||||
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
|
||||
if p.state == toolsState_Done {
|
||||
return nil, s
|
||||
}
|
||||
|
||||
p.buffer = append(p.buffer, s...)
|
||||
|
||||
if p.state == toolsState_LookingForTag {
|
||||
i, found := p.findTag()
|
||||
if i == -1 {
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
} else {
|
||||
content = string(p.buffer[:i])
|
||||
p.buffer = p.buffer[i:]
|
||||
}
|
||||
|
||||
// for models where { or [ are used as tool calling
|
||||
// tags, we only support parsing tools if the first non-
|
||||
// whitespace character is { or [
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
p.state = toolsState_Done
|
||||
return nil, content + string(p.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, content
|
||||
}
|
||||
|
||||
p.state = toolsState_ToolCalling
|
||||
}
|
||||
|
||||
for {
|
||||
call := p.parseToolCall()
|
||||
if call == nil {
|
||||
break
|
||||
}
|
||||
|
||||
calls = append(calls, *call)
|
||||
}
|
||||
|
||||
if p.done() {
|
||||
p.state = toolsState_Done
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
}
|
||||
|
||||
return calls, content
|
||||
}
|
||||
|
||||
// findTag searches the buffer to find and handle a tool calling tag
|
||||
// returning true if the tag was found and false otherwise, and
|
||||
// a string content signaling any content that should be sent back to the user
|
||||
func (p *Parser) findTag() (int, bool) {
|
||||
// First check for complete substring anywhere in s
|
||||
if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
|
||||
return i, true
|
||||
}
|
||||
|
||||
// Then check for partial suffix overlap
|
||||
max := min(len(p.buffer), len(p.tag))
|
||||
for i := max; i > 0; i-- {
|
||||
if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
|
||||
return len(p.buffer) - i, false
|
||||
}
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// parseToolCall finds the next complete tool call in the buffer
|
||||
// incrementing n and advancing the buffer.
|
||||
func (p *Parser) parseToolCall() *api.ToolCall {
|
||||
var tool *api.Tool
|
||||
var end int = len(p.buffer)
|
||||
var i int
|
||||
|
||||
// find tool name
|
||||
for _, t := range p.tools {
|
||||
n := t.Function.Name
|
||||
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
|
||||
if i+len(n) < end {
|
||||
tool = &t
|
||||
end = i + len(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tool == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only look for arguments if the tool has parameters
|
||||
args := map[string]any{}
|
||||
if len(tool.Function.Parameters.Properties) > 0 {
|
||||
if args, i = p.findArguments(*tool); args == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if i > end {
|
||||
end = i
|
||||
}
|
||||
}
|
||||
|
||||
tc := &api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: tool.Function.Name,
|
||||
Arguments: args,
|
||||
Index: p.n,
|
||||
},
|
||||
}
|
||||
|
||||
p.n++
|
||||
p.buffer = p.buffer[end:]
|
||||
return tc
|
||||
}
|
||||
|
||||
// findArguments returns the first object that appears to be
|
||||
// arguments for the provided tool, returning nil
|
||||
func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
|
||||
if len(p.buffer) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// no arguments to parse
|
||||
if len(tool.Function.Parameters.Properties) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var braces int
|
||||
var start int = -1
|
||||
var end int
|
||||
var object []byte
|
||||
|
||||
// find any outer json object
|
||||
for i, c := range p.buffer {
|
||||
if c == '{' {
|
||||
braces++
|
||||
if start == -1 {
|
||||
start = i
|
||||
}
|
||||
}
|
||||
|
||||
if c == '}' {
|
||||
if start != -1 {
|
||||
braces--
|
||||
if braces == 0 {
|
||||
end = i + 1
|
||||
object = p.buffer[start:end]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if braces > 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
|
||||
// not valid json
|
||||
if err := json.Unmarshal(object, &data); err != nil {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var find func(obj any) map[string]any
|
||||
find = func(obj any) map[string]any {
|
||||
switch obj := obj.(type) {
|
||||
case map[string]any:
|
||||
found := true
|
||||
for key := range obj {
|
||||
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
|
||||
found = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
return obj
|
||||
}
|
||||
|
||||
for _, value := range obj {
|
||||
if result := find(value); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range obj {
|
||||
if result := find(item); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
result := find(data)
|
||||
if result != nil {
|
||||
return result, end
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// done checks if the parser is done parsing by looking
|
||||
// for closing tag. currently only } and ] are supported
|
||||
// for closing tags as {} or [] pairs may not always
|
||||
// represent tool calls and we need to send the content back
|
||||
func (p *Parser) done() bool {
|
||||
var open, close rune
|
||||
switch p.tag {
|
||||
case "{":
|
||||
open, close = '{', '}'
|
||||
case "[":
|
||||
open, close = '[', ']'
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, c := range p.buffer {
|
||||
if c == byte(open) {
|
||||
count++
|
||||
} else if c == byte(close) {
|
||||
count--
|
||||
if count == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Content returns any remaining content that
|
||||
// should be sent to the user. This should be the empty string
|
||||
// string unless the tag is { or [ and a tool call was not found
|
||||
func (p *Parser) Content() string {
|
||||
if p.n > 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
return string(p.buffer)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
1487
tools/tools_test.go
1487
tools/tools_test.go
File diff suppressed because it is too large
Load Diff
@@ -1,222 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
||||
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
||||
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The extracted text following the first ".ToolCalls" condition found
|
||||
// - bool: Whether a ".ToolCalls" condition was found in the template
|
||||
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("template or tree is nil")
|
||||
return "", false
|
||||
}
|
||||
|
||||
var result string
|
||||
var found bool
|
||||
|
||||
var walk func(nodes []parse.Node)
|
||||
walk = func(nodes []parse.Node) {
|
||||
for _, node := range nodes {
|
||||
if found {
|
||||
return
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if isToolCallsNode(n) {
|
||||
// Collect immediate TextNode(s) at start of IfNode's list
|
||||
var sb strings.Builder
|
||||
for _, innerNode := range n.List.Nodes {
|
||||
if tn, ok := innerNode.(*parse.TextNode); ok {
|
||||
sb.Write(tn.Text)
|
||||
} else {
|
||||
// Stop at first non-text node
|
||||
break
|
||||
}
|
||||
}
|
||||
result = sb.String()
|
||||
found = true
|
||||
return
|
||||
}
|
||||
// Recurse into child nodes
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.ListNode:
|
||||
walk(n.Nodes)
|
||||
case *parse.RangeNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.WithNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
default:
|
||||
// Continue to next node
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
walk(tmpl.Tree.Root.Nodes)
|
||||
return result, found
|
||||
}
|
||||
|
||||
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
||||
func isToolCallsNode(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
if slices.Contains(field.Ident, "ToolCalls") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
tokenText = strings.TrimSpace(tokenText)
|
||||
tokenText = strings.ReplaceAll(tokenText, "\r", "")
|
||||
tokenText = strings.ReplaceAll(tokenText, "\n", " ")
|
||||
|
||||
return tokenText
|
||||
}
|
||||
|
||||
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
||||
//
|
||||
// Returns:
|
||||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||
// - error: Error if parsing failed
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
|
||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
}
|
||||
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
|
||||
//
|
||||
// Returns:
|
||||
// - int: The starting index in s where the suffix overlap begins
|
||||
func suffixOverlap(s, prefix string) int {
|
||||
max := min(len(prefix), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, prefix[:i]) {
|
||||
return len(s) - i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
||||
//
|
||||
// Returns:
|
||||
// - string: The name of the tool call
|
||||
// - string: The arguments of the tool call
|
||||
// - error: Error if parsing failed
|
||||
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Extract JSON object between curly braces
|
||||
// JSON arrays are also valid as they will not be repeated in the template
|
||||
output := b.String()
|
||||
start := strings.Index(output, "{")
|
||||
end := strings.LastIndex(output, "}")
|
||||
if start == -1 || end == -1 || start > end {
|
||||
return "", "", errors.New("no valid JSON object found in template output")
|
||||
}
|
||||
jsonStr := output[start : end+1]
|
||||
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Find name and arguments fields
|
||||
for k, v := range obj {
|
||||
if str, ok := v.(string); ok && str == "@@name@@" {
|
||||
name = k
|
||||
} else if _, ok := v.(map[string]any); ok {
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
|
||||
return "", "", errors.New("missing required fields in tool call template")
|
||||
}
|
||||
|
||||
return name, arguments, nil
|
||||
}
|
||||
|
||||
// collect recursively traverses an object to collect all nested maps
|
||||
//
|
||||
// Returns:
|
||||
// - []map[string]any: A slice of all nested maps found in the object
|
||||
func collect(obj any) []map[string]any {
|
||||
var all []map[string]any
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
@@ -1,497 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func TestExtractToolCallsFormat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil template",
|
||||
template: "",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "basic tool call with text",
|
||||
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||
want: "Hello world",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json\n",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call in range",
|
||||
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "nested if without tool calls",
|
||||
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, found := extractToolCallsFormat(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
if found != tc.found {
|
||||
t.Errorf("got found %v, want %v", found, tc.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPrefix(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with action prefix",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action: ```json",
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools[",
|
||||
},
|
||||
{
|
||||
name: "tool call with angle brackets",
|
||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||
want: "Hello, world! <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "multiple tool call formats",
|
||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||
want: "[tool_call] <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "single angle bracket tool call",
|
||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "incomplete angle bracket after tool call",
|
||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||
want: "[tool_call] <",
|
||||
},
|
||||
{
|
||||
name: "angle bracket prefix with tool call",
|
||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||
want: "> <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL] [",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL][",
|
||||
},
|
||||
{
|
||||
name: "tool call with pipe delimiters",
|
||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||
want: "<|tool_call|>",
|
||||
},
|
||||
{
|
||||
name: "tool with no prefix",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
got := toolPrefix(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolTemplate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call range",
|
||||
template: "{{range .ToolCalls}}test{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: "{{range .Other}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nested tool calls",
|
||||
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "tool calls in if statement",
|
||||
template: "{{if .ToolCalls}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := template.Parse(tmpl.Root.String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
_, err = toolTemplate(parsed)
|
||||
if err != nil && tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuffixOverlap(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
s string
|
||||
d string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "no overlap",
|
||||
s: "hello world",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "full overlap",
|
||||
s: "<tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "partial overlap",
|
||||
s: "text <tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 5,
|
||||
},
|
||||
{
|
||||
name: "delimiter longer than string",
|
||||
s: "<tool>",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
s: "",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty delimiter",
|
||||
s: "<tool_call>",
|
||||
d: "",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "single char overlap",
|
||||
s: "test<",
|
||||
d: "<tool_call>",
|
||||
want: 4,
|
||||
},
|
||||
{
|
||||
name: "partial tool call",
|
||||
s: "hello <tool_",
|
||||
d: "<tool_call>",
|
||||
want: 6,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := suffixOverlap(tt.s, tt.d)
|
||||
if got != tt.want {
|
||||
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
wantName string
|
||||
wantArgs string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with whitespace",
|
||||
template: `{{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}}
|
||||
{{end}}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with extra content",
|
||||
template: `Before {{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: `{{if .Something}}no tools here{{end}}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: ``,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "prefix within tool call",
|
||||
template: `{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
<tool_call>
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
</tool_call>{{ end }}{{- end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JSON array",
|
||||
template: `{{ range .ToolCalls }}
|
||||
[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing name field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing arguments field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}"}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
gotName, gotArgs, err := extractToolArgs(tmpl)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotArgs != tt.wantArgs {
|
||||
t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollect(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
obj any
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
name: "simple map",
|
||||
obj: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested map",
|
||||
obj: map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": "value",
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"outer": map[string]any{"inner": "value"}},
|
||||
{"inner": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array of maps",
|
||||
obj: []any{
|
||||
map[string]any{"key1": "val1"},
|
||||
map[string]any{"key2": "val2"},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key1": "val1"},
|
||||
{"key2": "val2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
obj: map[string]any{
|
||||
"l1": map[string]any{
|
||||
"l2": map[string]any{
|
||||
"l3": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||
{"l2": map[string]any{"l3": "value"}},
|
||||
{"l3": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-map value",
|
||||
obj: "string",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := collect(tt.obj)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare each map in the result
|
||||
for i := range tt.want {
|
||||
if !mapsEqual(got[i], tt.want[i]) {
|
||||
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mapsEqual compares two maps for deep equality
|
||||
func mapsEqual(m1, m2 map[string]any) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v1 := range m1 {
|
||||
v2, ok := m2[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch val1 := v1.(type) {
|
||||
case map[string]any:
|
||||
val2, ok := v2.(map[string]any)
|
||||
if !ok || !mapsEqual(val1, val2) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user