mirror of
https://github.com/ollama/ollama.git
synced 2026-02-27 12:36:54 -05:00
Compare commits
30 Commits
v0.7.1
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cf62838ce | ||
|
|
22aed78048 | ||
|
|
d3cbbbfd85 | ||
|
|
e8d1933b99 | ||
|
|
735e80787b | ||
|
|
8e3998b9dd | ||
|
|
a8ed68bd93 | ||
|
|
2ae65ae471 | ||
|
|
a3b6886b7d | ||
|
|
c6a6d7294d | ||
|
|
2cf007c9d1 | ||
|
|
0683efa637 | ||
|
|
0943001193 | ||
|
|
5c42800fca | ||
|
|
65f10c2823 | ||
|
|
aaa7818000 | ||
|
|
f15ffc4320 | ||
|
|
5f57b0ef42 | ||
|
|
aa25aff10d | ||
|
|
ea79003180 | ||
|
|
9239a254e0 | ||
|
|
066d0f4746 | ||
|
|
aea6fb9b58 | ||
|
|
012cf65340 | ||
|
|
a45231af47 | ||
|
|
2307fc2bcd | ||
|
|
6623898198 | ||
|
|
eda472df1b | ||
|
|
f18e0cb550 | ||
|
|
e8b981fa5d |
@@ -406,6 +406,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -449,6 +450,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
@@ -585,6 +587,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
|
||||
### Supported backends
|
||||
|
||||
@@ -24,7 +24,10 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -76,6 +79,14 @@ func NewClient(base *url.URL, http *http.Client) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
|
||||
token, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||
var reqBody io.Reader
|
||||
var data []byte
|
||||
@@ -97,6 +108,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -106,6 +132,10 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
respObj, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -143,6 +173,22 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
var err error
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -152,6 +198,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
request.Header.Set("Accept", "application/x-ndjson")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
response, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
34
api/types.go
34
api/types.go
@@ -83,6 +83,12 @@ type GenerateRequest struct {
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -108,6 +114,10 @@ type ChatRequest struct {
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -126,8 +136,11 @@ func (t Tool) String() string {
|
||||
// role ("system", "user", or "assistant"), the content and an optional list
|
||||
// of images.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
@@ -444,12 +457,13 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
@@ -478,6 +492,10 @@ type GenerateResponse struct {
|
||||
// Response is the textual response itself.
|
||||
Response string `json:"response"`
|
||||
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
|
||||
@@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
input: `{ }`,
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var req GenerateRequest
|
||||
err := json.Unmarshal([]byte(test.input), &req)
|
||||
if test.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
178
cmd/cmd.go
178
cmd/cmd.go
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -46,6 +47,23 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
@@ -265,6 +283,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
req := &api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
|
||||
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
@@ -299,6 +320,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
hidethinking, err := cmd.Flags().GetBool("hidethinking")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.HideThinking = hidethinking
|
||||
|
||||
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -362,6 +399,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
@@ -923,17 +965,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
type generateContextKey string
|
||||
|
||||
type runOptions struct {
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
HideThinking bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -989,6 +1033,26 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
}
|
||||
}
|
||||
|
||||
func thinkingOutputOpeningText(plainText bool) string {
|
||||
text := "Thinking...\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||
}
|
||||
|
||||
func thinkingOutputClosingText(plainText bool) string {
|
||||
text := "...done thinking.\n\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||
}
|
||||
|
||||
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -1016,14 +1080,34 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
p.StopAndClear()
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
latest = response
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
// filtered out anyway since current models don't expect them unless you're
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
@@ -1040,6 +1124,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
@@ -1101,13 +1186,32 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
p.StopAndClear()
|
||||
|
||||
latest = response
|
||||
content := response.Response
|
||||
|
||||
if response.Response != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
if response.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(plainText))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(plainText))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
@@ -1133,6 +1237,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
@@ -1348,6 +1453,8 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
@@ -1399,7 +1506,6 @@ func NewCLI() *cobra.Command {
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ListRunningHandler,
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
@@ -1488,3 +1594,45 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
// If the user has explicitly set thinking options, either through the CLI or
|
||||
// through the `/set think` or `set nothink` interactive options, then we
|
||||
// respect them. Otherwise, we check model capabilities to see if the model
|
||||
// supports thinking. If the model does support thinking, we enable it.
|
||||
// Otherwise, we unset the thinking option (which is different than setting it
|
||||
// to false).
|
||||
//
|
||||
// If capabilities are not provided, we fetch them from the server.
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
|
||||
if explicitlySetByUser {
|
||||
return runOpts.Think, nil
|
||||
}
|
||||
|
||||
if caps == nil {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := client.Show(context.Background(), &api.ShowRequest{
|
||||
Model: runOpts.Model,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caps = &ret.Capabilities
|
||||
}
|
||||
|
||||
thinkingSupported := false
|
||||
for _, cap := range *caps {
|
||||
if cap == model.CapabilityThinking {
|
||||
thinkingSupported = true
|
||||
}
|
||||
}
|
||||
|
||||
if thinkingSupported {
|
||||
thinking := true
|
||||
return &thinking, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
|
||||
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@@ -128,6 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -195,11 +198,19 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
continue
|
||||
@@ -260,6 +271,22 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
@@ -448,6 +475,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
|
||||
@@ -23,7 +23,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
path := strings.Split(link, "Ollama.app")
|
||||
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
|
||||
@@ -45,14 +45,11 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
// log.Printf("XXX attempting to start app %s", appExe)
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe)
|
||||
// TODO - these hide flags aren't working - still pops up a command window for some reason
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "hidden")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
// TODO this didn't help either...
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -74,7 +71,16 @@ func isProcRunning(procName string) []uint32 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
pids = pids[:ret]
|
||||
if ret > uint32(len(pids)) {
|
||||
pids = make([]uint32, ret+10)
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if ret < uint32(len(pids)) {
|
||||
pids = pids[:ret]
|
||||
}
|
||||
var matches []uint32
|
||||
for _, pid := range pids {
|
||||
if pid == 0 {
|
||||
|
||||
63
cmd/warn_thinking_test.go
Normal file
63
cmd/warn_thinking_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Test that a warning is printed when thinking is requested but not supported.
|
||||
func TestWarnMissingThinking(t *testing.T) {
|
||||
cases := []struct {
|
||||
capabilities []model.Capability
|
||||
expectWarn bool
|
||||
}{
|
||||
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
|
||||
{capabilities: []model.Capability{}, expectWarn: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
|
||||
}
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
resp := api.ShowResponse{Capabilities: tc.capabilities}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
ensureThinkingSupport(t.Context(), client, "m")
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
out, _ := io.ReadAll(r)
|
||||
|
||||
warned := strings.Contains(string(out), "warning:")
|
||||
if tc.expectWarn && !warned {
|
||||
t.Errorf("expected warning, got none")
|
||||
}
|
||||
if !tc.expectWarn && warned {
|
||||
t.Errorf("did not expect warning, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
32
docs/api.md
32
docs/api.md
@@ -43,6 +43,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
- `prompt`: the prompt to generate a response for
|
||||
- `suffix`: the text after the model response
|
||||
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
@@ -490,11 +491,13 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
||||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
- `tools`: list of tools in JSON for the model to use if supported
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
The `message` object has the following fields:
|
||||
|
||||
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
||||
- `content`: the content of the message
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
|
||||
@@ -1154,11 +1157,15 @@ A single JSON object will be returned.
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "deepseek-r1:latest",
|
||||
"model": "deepseek-r1:latest",
|
||||
"modified_at": "2025-05-10T08:06:48.639712648-07:00",
|
||||
"size": 4683075271,
|
||||
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
|
||||
|
||||
"model": "codellama:13b",
|
||||
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
|
||||
"size": 7365960935,
|
||||
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
|
||||
"capabilities": [
|
||||
"completion"
|
||||
],
|
||||
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
@@ -1171,11 +1178,16 @@ A single JSON object will be returned.
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "llama3.2:latest",
|
||||
"model": "llama3.2:latest",
|
||||
"modified_at": "2025-05-04T17:37:44.706015396-07:00",
|
||||
"size": 2019393189,
|
||||
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
|
||||
|
||||
"model": "llama4:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
|
||||
@@ -118,7 +118,7 @@ To run tests, use `go test`:
|
||||
go test ./...
|
||||
```
|
||||
|
||||
> NOTE: In rare cirumstances, you may nedd to change a package using the new
|
||||
> NOTE: In rare cirumstances, you may need to change a package using the new
|
||||
> "synctest" package in go1.24.
|
||||
>
|
||||
> If you do not have the "synctest" package enabled, you will not see build or
|
||||
|
||||
@@ -132,22 +132,12 @@ success
|
||||
|
||||
### Supported Quantizations
|
||||
|
||||
- `q4_0`
|
||||
- `q4_1`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q8_0`
|
||||
|
||||
#### K-means Quantizations
|
||||
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
|
||||
|
||||
## Sharing your model on ollama.com
|
||||
|
||||
@@ -183,6 +183,8 @@ var (
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
|
||||
350
fs/gguf/gguf.go
Normal file
350
fs/gguf/gguf.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
typeUint8 uint32 = iota
|
||||
typeInt8
|
||||
typeUint16
|
||||
typeInt16
|
||||
typeUint32
|
||||
typeInt32
|
||||
typeFloat32
|
||||
typeBool
|
||||
typeString
|
||||
typeArray
|
||||
typeUint64
|
||||
typeInt64
|
||||
typeFloat64
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("unsupported")
|
||||
|
||||
type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *readSeeker
|
||||
bts []byte
|
||||
}
|
||||
|
||||
func Open(path string) (f *File, err error) {
|
||||
f = &File{bts: make([]byte, 4096)}
|
||||
f.file, err = os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.reader = newReadSeeker(f.file, 32<<10)
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||
}
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.Version != 3 {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensors.doneFunc = func() error {
|
||||
offset, err := f.reader.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||
return nil
|
||||
}
|
||||
|
||||
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) readTensor() (TensorInfo, error) {
|
||||
name, err := readString(f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
dims, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
shape := make([]uint64, dims)
|
||||
for i := range dims {
|
||||
shape[i], err = read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
type_, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
offset, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
key, err := readString(f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
value, err := func() (any, error) {
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return read[uint8](f)
|
||||
case typeInt8:
|
||||
return read[int8](f)
|
||||
case typeUint16:
|
||||
return read[uint16](f)
|
||||
case typeInt16:
|
||||
return read[int16](f)
|
||||
case typeUint32:
|
||||
return read[uint32](f)
|
||||
case typeInt32:
|
||||
return read[int32](f)
|
||||
case typeUint64:
|
||||
return read[uint64](f)
|
||||
case typeInt64:
|
||||
return read[int64](f)
|
||||
case typeFloat32:
|
||||
return read[float32](f)
|
||||
case typeFloat64:
|
||||
return read[float64](f)
|
||||
case typeBool:
|
||||
return read[bool](f)
|
||||
case typeString:
|
||||
return readString(f)
|
||||
case typeArray:
|
||||
return readArray(f)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
return KeyValue{
|
||||
Key: key,
|
||||
Value: Value{value},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func read[T any](f *File) (t T, err error) {
|
||||
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func readArray(f *File) (any, error) {
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return readArrayData[uint8](f, n)
|
||||
case typeInt8:
|
||||
return readArrayData[int8](f, n)
|
||||
case typeUint16:
|
||||
return readArrayData[uint16](f, n)
|
||||
case typeInt16:
|
||||
return readArrayData[int16](f, n)
|
||||
case typeUint32:
|
||||
return readArrayData[uint32](f, n)
|
||||
case typeInt32:
|
||||
return readArrayData[int32](f, n)
|
||||
case typeUint64:
|
||||
return readArrayData[uint64](f, n)
|
||||
case typeInt64:
|
||||
return readArrayData[int64](f, n)
|
||||
case typeFloat32:
|
||||
return readArrayData[float32](f, n)
|
||||
case typeFloat64:
|
||||
return readArrayData[float64](f, n)
|
||||
case typeBool:
|
||||
return readArrayData[bool](f, n)
|
||||
case typeString:
|
||||
return readArrayString(f, n)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
func (f *File) KeyValue(key string) KeyValue {
|
||||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||
key = f.KeyValue("general.architecture").String() + "." + key
|
||||
}
|
||||
|
||||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||
return kv.Key == key
|
||||
}); index >= 0 {
|
||||
return f.keyValues.values[index]
|
||||
}
|
||||
|
||||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||
if keyValue.Key == key {
|
||||
return keyValue
|
||||
}
|
||||
}
|
||||
|
||||
return KeyValue{}
|
||||
}
|
||||
|
||||
func (f *File) NumKeyValues() int {
|
||||
return int(f.keyValues.count)
|
||||
}
|
||||
|
||||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
return f.keyValues.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
|
||||
return TensorInfo{}
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
t := f.TensorInfo(name)
|
||||
if t.NumBytes() == 0 {
|
||||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
320
fs/gguf/gguf_test.go
Normal file
320
fs/gguf/gguf_test.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
// Setup
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test.gguf")
|
||||
|
||||
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"general.alignment": int64(32),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := Open(tempFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Test
|
||||
if got := f.NumKeyValues(); got != 2 {
|
||||
t.Errorf("NumKeyValues() = %d, want %d", got, 2)
|
||||
}
|
||||
if got := f.NumTensors(); got != 2 {
|
||||
t.Errorf("NumTensors() = %d, want %d", got, 2)
|
||||
}
|
||||
archKV := f.KeyValue("general.architecture")
|
||||
if archKV.Key == "" {
|
||||
t.Error("KeyValue(\"general.architecture\") not found")
|
||||
}
|
||||
if got := archKV.String(); got != "llama" {
|
||||
t.Errorf("KeyValue(\"general.architecture\").String() = %q, want %q", got, "llama")
|
||||
}
|
||||
alignKV := f.KeyValue("general.alignment")
|
||||
if alignKV.Key == "" {
|
||||
t.Error("KeyValue(\"general.alignment\") not found")
|
||||
}
|
||||
if got := alignKV.Int(); got != 32 {
|
||||
t.Errorf("KeyValue(\"general.alignment\").Int() = %d, want %d", got, 32)
|
||||
}
|
||||
expectedTensorNames := []string{"token_embd.weight", "output.weight"}
|
||||
var gotTensorNames []string
|
||||
for _, tensor := range f.TensorInfos() {
|
||||
gotTensorNames = append(gotTensorNames, tensor.Name)
|
||||
}
|
||||
if !slices.Equal(gotTensorNames, expectedTensorNames) {
|
||||
t.Errorf("tensor names = %v, want %v", gotTensorNames, expectedTensorNames)
|
||||
}
|
||||
tokenTensor := f.TensorInfo("token_embd.weight")
|
||||
if tokenTensor.Name != "token_embd.weight" {
|
||||
t.Error("TensorInfo(\"token_embd.weight\") not found")
|
||||
}
|
||||
if len(tokenTensor.Shape) == 0 {
|
||||
t.Error("TensorInfo(\"token_embd.weight\") has empty shape")
|
||||
}
|
||||
outputTensor := f.TensorInfo("output.weight")
|
||||
if outputTensor.Name != "output.weight" {
|
||||
t.Error("TensorInfo(\"output.weight\") not found")
|
||||
}
|
||||
if len(outputTensor.Shape) == 0 {
|
||||
t.Error("TensorInfo(\"output.weight\") has empty shape")
|
||||
}
|
||||
var gotKeyCount int
|
||||
for _, kv := range f.KeyValues() {
|
||||
gotKeyCount++
|
||||
if kv.Key == "" {
|
||||
t.Error("found key value with empty key")
|
||||
}
|
||||
}
|
||||
if gotKeyCount != 2 {
|
||||
t.Errorf("iterated key count = %d, want %d", gotKeyCount, 2)
|
||||
}
|
||||
tensorInfo, reader, err := f.TensorReader("token_embd.weight")
|
||||
if err != nil {
|
||||
t.Errorf("TensorReader(\"token_embd.weight\") error: %v", err)
|
||||
}
|
||||
if tensorInfo.Name != "token_embd.weight" {
|
||||
t.Errorf("TensorReader returned wrong tensor: %q", tensorInfo.Name)
|
||||
}
|
||||
if reader == nil {
|
||||
t.Error("TensorReader returned nil reader")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
// Create benchmark test file
|
||||
tempDir := b.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "benchmark.gguf")
|
||||
|
||||
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"general.alignment": int64(32),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get file info for reporting
|
||||
info, err := os.Stat(tempFile)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Logf("Benchmark file size: %d bytes", info.Size())
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
f, err := Open(tempFile)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Access some data to ensure it's actually being read
|
||||
_ = f.KeyValue("general.architecture").String()
|
||||
_ = f.KeyValue("general.alignment").Int()
|
||||
_ = f.NumTensors()
|
||||
_ = f.NumKeyValues()
|
||||
|
||||
// Iterate through some tensors
|
||||
count := 0
|
||||
for _, tensor := range f.TensorInfos() {
|
||||
_ = tensor.Name
|
||||
count++
|
||||
if count >= 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create test GGUF files
|
||||
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write GGUF magic
|
||||
if _, err := file.Write([]byte("GGUF")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write version
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write tensor count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
for key, value := range keyValues {
|
||||
if err := writeKeyValue(file, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write tensor info
|
||||
for _, tensor := range tensors {
|
||||
if err := writeTensorInfo(file, tensor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write some dummy tensor data
|
||||
dummyData := make([]byte, 1024)
|
||||
file.Write(dummyData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type testTensorInfo struct {
|
||||
Name string
|
||||
Shape []uint64
|
||||
Type uint32
|
||||
}
|
||||
|
||||
func writeKeyValue(file *os.File, key string, value any) error {
|
||||
// Write key length and key
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(key)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value based on type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := file.Write([]byte(v))
|
||||
return err
|
||||
case int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case bool:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case []string:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, s := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, i := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, f := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
|
||||
// Write tensor name
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(tensor.Name)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write dimensions
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, dim := range tensor.Shape {
|
||||
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write type
|
||||
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write offset (dummy value)
|
||||
return binary.Write(file, binary.LittleEndian, uint64(0))
|
||||
}
|
||||
102
fs/gguf/keyvalue.go
Normal file
102
fs/gguf/keyvalue.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type KeyValue struct {
|
||||
Key string
|
||||
Value
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
func (v Value) Int() int64 {
|
||||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||
func (v Value) Ints() (i64s []int64) {
|
||||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||
func (v Value) Uint() uint64 {
|
||||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||
func (v Value) Uints() (u64s []uint64) {
|
||||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||
func (v Value) Float() float64 {
|
||||
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||
func (v Value) Floats() (f64s []float64) {
|
||||
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||
func (v Value) Bool() bool {
|
||||
return value[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||
func (v Value) Bools() (bools []bool) {
|
||||
return values[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||
func (v Value) String() string {
|
||||
return value[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// IsNil checks if the Value is nil. It returns true if the value is nil or if it is a nil pointer, interface, slice, map, channel, or function.
|
||||
func (v Value) IsNil() bool {
|
||||
if v.value == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for nil pointers, interfaces, slices, maps, channels, and functions
|
||||
rv := reflect.ValueOf(v.value)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
|
||||
return rv.IsNil()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
208
fs/gguf/keyvalue_test.go
Normal file
208
fs/gguf/keyvalue_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||
for key, value := range values {
|
||||
if key == name {
|
||||
matched = value
|
||||
} else {
|
||||
unmatched = append(unmatched, value...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||
"float64": {float32(42), float64(42)},
|
||||
"string": {"42", "hello"},
|
||||
"bool": {true, false},
|
||||
}
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
matched, unmatched := split("int64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 42 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 0 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 42 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 0 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
matched, unmatched := split("float64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 42 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 0 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
matched, unmatched := split("string", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != v {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != "" {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
matched, unmatched := split("bool", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != v {
|
||||
t.Errorf("expected true, got %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != false {
|
||||
t.Errorf("expected false, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||
"float64s": {[]float32{42}, []float64{42}},
|
||||
"strings": {[]string{"42"}, []string{"hello"}},
|
||||
"bools": {[]bool{true}, []bool{false}},
|
||||
}
|
||||
|
||||
t.Run("int64s", func(t *testing.T) {
|
||||
matched, unmatched := split("int64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64s := kv.Ints(); i64s != nil {
|
||||
t.Errorf("expected nil, got %v", i64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64s", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64s := kv.Uints(); u64s != nil {
|
||||
t.Errorf("expected nil, got %v", u64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64s", func(t *testing.T) {
|
||||
matched, unmatched := split("float64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64s := kv.Floats(); f64s != nil {
|
||||
t.Errorf("expected nil, got %v", f64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
matched, unmatched := split("strings", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.Strings(); s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bools", func(t *testing.T) {
|
||||
matched, unmatched := split("bools", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bools(); b != nil {
|
||||
t.Errorf("expected nil, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
88
fs/gguf/lazy.go
Normal file
88
fs/gguf/lazy.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type lazy[T any] struct {
|
||||
count uint64
|
||||
next func() (T, bool)
|
||||
stop func()
|
||||
values []T
|
||||
|
||||
doneFunc func() error
|
||||
}
|
||||
|
||||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||
it := lazy[T]{}
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.values = make([]T, 0)
|
||||
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||
for i := range it.count {
|
||||
t, err := fn()
|
||||
if err != nil {
|
||||
slog.Error("error reading tensor", "index", i, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
it.values = append(it.values, t)
|
||||
if !yield(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if it.doneFunc != nil {
|
||||
it.doneFunc()
|
||||
}
|
||||
})
|
||||
|
||||
return &it, nil
|
||||
}
|
||||
|
||||
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for _, v := range g.All() {
|
||||
if !yield(v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := range int(g.count) {
|
||||
if i < len(g.values) {
|
||||
if !yield(i, g.values[i]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t, ok := g.next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if !yield(i, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) rest() (collected bool) {
|
||||
for {
|
||||
_, ok := g.next()
|
||||
collected = collected || ok
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return collected
|
||||
}
|
||||
34
fs/gguf/reader.go
Normal file
34
fs/gguf/reader.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type readSeeker struct {
|
||||
rs io.ReadSeeker
|
||||
br *bufio.Reader
|
||||
}
|
||||
|
||||
func newReadSeeker(rs io.ReadSeeker, size int) *readSeeker {
|
||||
return &readSeeker{
|
||||
rs: rs,
|
||||
br: bufio.NewReaderSize(rs, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *readSeeker) Read(p []byte) (int, error) {
|
||||
return b.br.Read(p)
|
||||
}
|
||||
|
||||
func (b *readSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||
if whence == io.SeekCurrent {
|
||||
offset -= int64(b.br.Buffered())
|
||||
}
|
||||
n, err := b.rs.Seek(offset, whence)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b.br.Reset(b.rs)
|
||||
return n, nil
|
||||
}
|
||||
284
fs/gguf/tensor.go
Normal file
284
fs/gguf/tensor.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Offset uint64
|
||||
Shape []uint64
|
||||
Type TensorType
|
||||
}
|
||||
|
||||
func (t TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range t.Shape {
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (t TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(t.NumValues()) * t.Type.NumBytes())
|
||||
}
|
||||
|
||||
func (t TensorInfo) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", t.Name),
|
||||
slog.Int64("offset", int64(t.Offset)),
|
||||
slog.Any("shape", t.Shape),
|
||||
slog.Int64("num_values", t.NumValues()),
|
||||
slog.Int64("num_bytes", t.NumBytes()),
|
||||
slog.Any("type", t.Type),
|
||||
)
|
||||
}
|
||||
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3
|
||||
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ2_XXS
|
||||
tensorTypeIQ2_XS
|
||||
tensorTypeIQ3_XXS
|
||||
tensorTypeIQ1_S
|
||||
tensorTypeIQ4_NL
|
||||
tensorTypeIQ3_S
|
||||
tensorTypeIQ2_S
|
||||
tensorTypeIQ4_XS
|
||||
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ1_M
|
||||
|
||||
TensorTypeBF16
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_0_4_4
|
||||
tensorTypeQ4_0_4_8
|
||||
tensorTypeQ4_0_8_8
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeTQ1_0
|
||||
tensorTypeTQ2_0
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
)
|
||||
|
||||
func (t TensorType) NumBytes() float64 {
|
||||
return float64(t.typeSize()) / float64(t.blockSize())
|
||||
}
|
||||
|
||||
func (t TensorType) typeSize() int64 {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + t.blockSize()/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + t.blockSize()/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + t.blockSize()/2
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + t.blockSize()/2
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + t.blockSize()
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + t.blockSize()
|
||||
case TensorTypeQ2_K:
|
||||
return t.blockSize()/16 + t.blockSize()/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
return t.blockSize()/8 + t.blockSize()/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + t.blockSize()/2
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + t.blockSize()/8 + t.blockSize()/2
|
||||
case TensorTypeQ6_K:
|
||||
return t.blockSize()/2 + t.blockSize()/4 + t.blockSize()/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + t.blockSize() + 2*t.blockSize()/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*t.blockSize()/8
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*t.blockSize()/8 + t.blockSize()/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/8
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + t.blockSize()/8 + t.blockSize()/16
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + t.blockSize()/2
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/8 + t.blockSize()/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/16
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + t.blockSize()/2 + t.blockSize()/64
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
return t.blockSize()/8 + t.blockSize()/16 + t.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) blockSize() int64 {
|
||||
switch t {
|
||||
case TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) String() string {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Uint64("value", uint64(t)),
|
||||
slog.String("name", strings.ToUpper(t.String())),
|
||||
slog.Int64("size", t.typeSize()),
|
||||
slog.Int64("block_size", t.blockSize()),
|
||||
slog.Float64("num_bytes", t.NumBytes()),
|
||||
)
|
||||
}
|
||||
1
integration/testdata/embed.json
vendored
1
integration/testdata/embed.json
vendored
File diff suppressed because one or more lines are too long
@@ -30,6 +30,11 @@ type Causal struct {
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curReserve = reserve
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
if !reserve {
|
||||
if !c.curReserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
@@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
if c.curReserve {
|
||||
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
|
||||
}
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
|
||||
102
llama/patches/0017-ggml-Export-GPU-UUIDs.patch
Normal file
102
llama/patches/0017-ggml-Export-GPU-UUIDs.patch
Normal file
@@ -0,0 +1,102 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Thu, 24 Apr 2025 14:48:51 -0700
|
||||
Subject: [PATCH] ggml: Export GPU UUIDs
|
||||
|
||||
This enables matching up devices and information reported by the backend
|
||||
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
|
||||
---
|
||||
ggml/include/ggml-backend.h | 1 +
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal.m | 1 +
|
||||
3 files changed, 35 insertions(+)
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index 74e46716..a880df33 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -152,6 +152,7 @@ extern "C" {
|
||||
struct ggml_backend_dev_props {
|
||||
const char * name;
|
||||
const char * description;
|
||||
+ const char * uuid;
|
||||
size_t memory_free;
|
||||
size_t memory_total;
|
||||
enum ggml_backend_dev_type type;
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index cb0d8528..4c829153 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
+ std::string uuid;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
+static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
|
||||
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
+ return ctx->uuid.c_str();
|
||||
+}
|
||||
+
|
||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_set_device(ctx->device);
|
||||
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
+ props->uuid = ggml_backend_cuda_device_get_uuid(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
+ #if !defined(GGML_USE_HIP)
|
||||
+ char uuid[64];
|
||||
+ snprintf(uuid, sizeof(uuid),
|
||||
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
+ (unsigned char)prop.uuid.bytes[0],
|
||||
+ (unsigned char)prop.uuid.bytes[1],
|
||||
+ (unsigned char)prop.uuid.bytes[2],
|
||||
+ (unsigned char)prop.uuid.bytes[3],
|
||||
+ (unsigned char)prop.uuid.bytes[4],
|
||||
+ (unsigned char)prop.uuid.bytes[5],
|
||||
+ (unsigned char)prop.uuid.bytes[6],
|
||||
+ (unsigned char)prop.uuid.bytes[7],
|
||||
+ (unsigned char)prop.uuid.bytes[8],
|
||||
+ (unsigned char)prop.uuid.bytes[9],
|
||||
+ (unsigned char)prop.uuid.bytes[10],
|
||||
+ (unsigned char)prop.uuid.bytes[11],
|
||||
+ (unsigned char)prop.uuid.bytes[12],
|
||||
+ (unsigned char)prop.uuid.bytes[13],
|
||||
+ (unsigned char)prop.uuid.bytes[14],
|
||||
+ (unsigned char)prop.uuid.bytes[15]
|
||||
+ );
|
||||
+ dev_ctx->uuid = uuid;
|
||||
+ #else
|
||||
+ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
|
||||
+ #endif
|
||||
+
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||
/* .reg = */ ®,
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index 1b56f858..ee4f2dcb 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_metal_device_get_name(dev);
|
||||
props->description = ggml_backend_metal_device_get_description(dev);
|
||||
+ props->uuid = "0";
|
||||
props->type = ggml_backend_metal_device_get_type(dev);
|
||||
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = (struct ggml_backend_dev_caps) {
|
||||
@@ -797,7 +797,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
slog.Error("post predict", "error", err)
|
||||
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -123,6 +124,10 @@ type DeviceMemory struct {
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string
|
||||
|
||||
// UUID is a unique persistent identifier for the device for matching
|
||||
// with system management libraries
|
||||
UUID string
|
||||
|
||||
// Weights is the per-layer memory needed for the model weights.
|
||||
Weights []Memory
|
||||
|
||||
@@ -133,6 +138,31 @@ type DeviceMemory struct {
|
||||
Graph Memory
|
||||
}
|
||||
|
||||
func memoryPresent(mem []Memory) bool {
|
||||
return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
|
||||
}
|
||||
|
||||
func (m DeviceMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if memoryPresent(m.Weights) {
|
||||
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||
}
|
||||
|
||||
if memoryPresent(m.Cache) {
|
||||
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||
}
|
||||
|
||||
if m.Graph.Size != 0 {
|
||||
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||
}
|
||||
|
||||
if len(attrs) > 0 && m.UUID != "" {
|
||||
attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...)
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// BackendMemory provides the amount of memory required to load the model
|
||||
// per device based on the BackendParams. In some cases, not all required
|
||||
// allocations will be known at this point. However, the size of the most recent
|
||||
@@ -150,6 +180,20 @@ type BackendMemory struct {
|
||||
GPUs []DeviceMemory
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if m.InputWeights.Size != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||
for _, g := range m.GPUs {
|
||||
attrs = append(attrs, slog.Any(g.Name, g))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
|
||||
@@ -136,6 +136,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
|
||||
requiredMemory.CPU.UUID = C.GoString(props.uuid)
|
||||
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
|
||||
|
||||
@@ -150,6 +153,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
})
|
||||
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
|
||||
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(d, &props)
|
||||
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
|
||||
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
|
||||
}
|
||||
|
||||
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -152,6 +152,7 @@ extern "C" {
|
||||
struct ggml_backend_dev_props {
|
||||
const char * name;
|
||||
const char * description;
|
||||
const char * uuid;
|
||||
size_t memory_free;
|
||||
size_t memory_total;
|
||||
enum ggml_backend_dev_type type;
|
||||
|
||||
33
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
33
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string uuid;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
return ctx->uuid.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_set_device(ctx->device);
|
||||
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->uuid = ggml_backend_cuda_device_get_uuid(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
#if !defined(GGML_USE_HIP)
|
||||
char uuid[64];
|
||||
snprintf(uuid, sizeof(uuid),
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
(unsigned char)prop.uuid.bytes[0],
|
||||
(unsigned char)prop.uuid.bytes[1],
|
||||
(unsigned char)prop.uuid.bytes[2],
|
||||
(unsigned char)prop.uuid.bytes[3],
|
||||
(unsigned char)prop.uuid.bytes[4],
|
||||
(unsigned char)prop.uuid.bytes[5],
|
||||
(unsigned char)prop.uuid.bytes[6],
|
||||
(unsigned char)prop.uuid.bytes[7],
|
||||
(unsigned char)prop.uuid.bytes[8],
|
||||
(unsigned char)prop.uuid.bytes[9],
|
||||
(unsigned char)prop.uuid.bytes[10],
|
||||
(unsigned char)prop.uuid.bytes[11],
|
||||
(unsigned char)prop.uuid.bytes[12],
|
||||
(unsigned char)prop.uuid.bytes[13],
|
||||
(unsigned char)prop.uuid.bytes[14],
|
||||
(unsigned char)prop.uuid.bytes[15]
|
||||
);
|
||||
dev_ctx->uuid = uuid;
|
||||
#else
|
||||
dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
|
||||
#endif
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||
/* .reg = */ ®,
|
||||
|
||||
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_metal_device_get_name(dev);
|
||||
props->description = ggml_backend_metal_device_get_description(dev);
|
||||
props->uuid = "0";
|
||||
props->type = ggml_backend_metal_device_get_type(dev);
|
||||
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = (struct ggml_backend_dev_caps) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -210,6 +211,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type lazyIdsString struct {
|
||||
ids []int32
|
||||
}
|
||||
|
||||
func (l lazyIdsString) LogValue() slog.Value {
|
||||
return slog.AnyValue(fmt.Sprint(l.ids))
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
@@ -234,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -61,6 +61,8 @@ const (
|
||||
ColorGrey = Esc + "[38;5;245m"
|
||||
ColorDefault = Esc + "[0m"
|
||||
|
||||
ColorBold = Esc + "[1m"
|
||||
|
||||
StartBracketedPaste = Esc + "[?2004h"
|
||||
EndBracketedPaste = Esc + "[?2004l"
|
||||
)
|
||||
|
||||
@@ -464,6 +464,10 @@ type downloadOpts struct {
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is is empty")
|
||||
}
|
||||
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -23,9 +23,10 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -37,6 +38,7 @@ var (
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
@@ -71,22 +73,20 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
r, err := os.Open(m.ModelPath)
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer r.Close()
|
||||
defer f.Close()
|
||||
|
||||
f, err := ggml.Decode(r, 1024)
|
||||
if err == nil {
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
embedding := f.KeyValue("pooling_type")
|
||||
if !embedding.Value.IsNil() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
vision := f.KeyValue("vision.block_count")
|
||||
if !vision.Value.IsNil() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
@@ -111,6 +111,12 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if openingTag != "" && closingTag != "" {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@@ -127,6 +133,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
}
|
||||
|
||||
for _, cap := range want {
|
||||
@@ -141,11 +148,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
if slices.Contains(errs, errCapabilityThinking) {
|
||||
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
|
||||
// append a message to the existing error
|
||||
return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Model) String() string {
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -13,81 +12,200 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Constants for GGUF magic bytes and version
|
||||
var (
|
||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
||||
ggufVer = uint32(3) // Version 3
|
||||
// GGUF type constants (matching gguf package)
|
||||
const (
|
||||
typeUint8 = uint32(0)
|
||||
typeInt8 = uint32(1)
|
||||
typeUint16 = uint32(2)
|
||||
typeInt16 = uint32(3)
|
||||
typeUint32 = uint32(4)
|
||||
typeInt32 = uint32(5)
|
||||
typeFloat32 = uint32(6)
|
||||
typeBool = uint32(7)
|
||||
typeString = uint32(8)
|
||||
typeArray = uint32(9)
|
||||
typeUint64 = uint32(10)
|
||||
typeInt64 = uint32(11)
|
||||
typeFloat64 = uint32(12)
|
||||
)
|
||||
|
||||
// Helper function to create mock GGUF data
|
||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
type testTensorInfo struct {
|
||||
Name string
|
||||
Shape []uint64
|
||||
Type uint32
|
||||
}
|
||||
|
||||
// Write GGUF header
|
||||
buf.Write(ggufMagic)
|
||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
||||
|
||||
// Write tensor count (0 for our test)
|
||||
var numTensors uint64 = 0
|
||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
||||
|
||||
// Calculate number of metadata entries
|
||||
numMetaEntries := uint64(1) // architecture entry
|
||||
if vision {
|
||||
numMetaEntries++
|
||||
// Helper function to create test GGUF files (matching gguf package approach)
|
||||
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Add embedding entry if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
numMetaEntries++
|
||||
}
|
||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
||||
defer file.Close()
|
||||
|
||||
// Write architecture metadata
|
||||
archKey := "general.architecture"
|
||||
keyLen := uint64(len(archKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(archKey)
|
||||
|
||||
// String type (8)
|
||||
var strType uint32 = 8
|
||||
binary.Write(&buf, binary.LittleEndian, strType)
|
||||
|
||||
// String length
|
||||
strLen := uint64(len(architecture))
|
||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
||||
buf.WriteString(architecture)
|
||||
|
||||
if vision {
|
||||
visionKey := architecture + ".vision.block_count"
|
||||
keyLen = uint64(len(visionKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(visionKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var countVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
||||
}
|
||||
// Write embedding metadata if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
poolKey := architecture + ".pooling_type"
|
||||
keyLen = uint64(len(poolKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(poolKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var poolingVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
||||
// Write GGUF magic
|
||||
if _, err := file.Write([]byte("GGUF")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
// Write version
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write tensor count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
for key, value := range keyValues {
|
||||
if err := writeKeyValue(file, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write tensor info
|
||||
for _, tensor := range tensors {
|
||||
if err := writeTensorInfo(file, tensor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write some dummy tensor data
|
||||
dummyData := make([]byte, 1024)
|
||||
file.Write(dummyData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeKeyValue(file *os.File, key string, value any) error {
|
||||
// Write key length and key
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(key)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value based on type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(typeString)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := file.Write([]byte(v))
|
||||
return err
|
||||
case int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case uint32:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeUint32); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case bool:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(typeFloat64)); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case []string:
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(typeArray)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, s := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(typeArray)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, i := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, f := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
|
||||
// Write tensor name
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(tensor.Name)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write dimensions
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, dim := range tensor.Shape {
|
||||
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write type
|
||||
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write offset (dummy value)
|
||||
return binary.Write(file, binary.LittleEndian, uint64(0))
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
@@ -101,13 +219,38 @@ func TestModelCapabilities(t *testing.T) {
|
||||
// Create a simple model file for tests that don't depend on GGUF content
|
||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||
|
||||
if err := errors.Join(
|
||||
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
// Create completion model (llama architecture without vision)
|
||||
if err := createTestGGUFFile(completionModelPath, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatalf("Failed to create completion model file: %v", err)
|
||||
}
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
if err := createTestGGUFFile(visionModelPath, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatalf("Failed to create vision model file: %v", err)
|
||||
}
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
if err := createTestGGUFFile(embeddingModelPath, map[string]any{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
if err := os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
@@ -231,12 +374,29 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
|
||||
if err := errors.Join(
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
if err := createTestGGUFFile(visionModelPath, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatalf("Failed to create vision model file: %v", err)
|
||||
}
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
if err := createTestGGUFFile(embeddingModelPath, map[string]any{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
if err := os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
|
||||
124
server/model.go
124
server/model.go
@@ -10,9 +10,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {
|
||||
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
func parseObjects(s string) []map[string]any {
|
||||
var objs []map[string]any
|
||||
for offset := 0; offset < len(s); {
|
||||
var obj map[string]any
|
||||
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
||||
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||
// skip over any syntax errors
|
||||
offset += int(syntax.Offset)
|
||||
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
|
||||
// skip over any unmarshalable types
|
||||
offset += int(unmarshalType.Offset)
|
||||
} else if err != nil {
|
||||
return nil
|
||||
} else {
|
||||
offset += int(decoder.InputOffset())
|
||||
objs = append(objs, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return objs
|
||||
}
|
||||
|
||||
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
templateObjects := parseObjects(b.String())
|
||||
if len(templateObjects) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
for k, v := range templateObjects[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
responseObjects := parseObjects(s)
|
||||
if len(responseObjects) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// collect all nested objects
|
||||
var collect func(any) []map[string]any
|
||||
collect = func(obj any) (all []map[string]any) {
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
var objs []map[string]any
|
||||
for _, p := range responseObjects {
|
||||
objs = append(objs, collect(p)...)
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls, len(toolCalls) > 0
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
func TestExecuteWithTools(t *testing.T) {
|
||||
p := filepath.Join("testdata", "tools")
|
||||
cases := []struct {
|
||||
model string
|
||||
output string
|
||||
ok bool
|
||||
}{
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
|
||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
|
||||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"command-r-plus", "Action: ```json" + `
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada"
|
||||
}
|
||||
}
|
||||
]
|
||||
` + "```", true},
|
||||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"llama3-groq-tool-use", `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
||||
</tool_call>`, true},
|
||||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
|
||||
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
|
||||
}
|
||||
|
||||
var tools []api.Tool
|
||||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
calls := []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.model, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("template", func(t *testing.T) {
|
||||
var actual bytes.Buffer
|
||||
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse", func(t *testing.T) {
|
||||
m := &Model{Template: tmpl}
|
||||
actual, ok := m.parseToolCalls(tt.output)
|
||||
if ok != tt.ok {
|
||||
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||
}
|
||||
|
||||
if tt.ok {
|
||||
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseObjects(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `{"name": "get_current_weather", "arguments": `,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
got := parseObjects(tc.input)
|
||||
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
|
||||
func GetManifestPath() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
@@ -139,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
|
||||
@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
@@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
}
|
||||
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
||||
168
server/routes.go
168
server/routes.go
@@ -17,7 +17,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -38,6 +37,8 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/tools"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -185,6 +186,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
}
|
||||
if req.Think != nil && *req.Think {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
// TODO(drifkin): consider adding a warning if it's false and the model
|
||||
// doesn't support thinking. It's not strictly required, but it can be a
|
||||
// hint that the user is on an older qwen3/r1 model that doesn't have an
|
||||
// updated template supporting thinking
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
@@ -253,6 +261,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
}
|
||||
|
||||
values.Think = req.Think != nil && *req.Think
|
||||
values.IsThinkSet = req.Think != nil
|
||||
|
||||
var b bytes.Buffer
|
||||
if req.Context != nil {
|
||||
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
|
||||
@@ -272,6 +283,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
prompt = b.String()
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinking.Parser{
|
||||
OpeningTag: openingTag,
|
||||
ClosingTag: closingTag,
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
// TODO (jmorganca): avoid building the response twice both here and below
|
||||
@@ -296,6 +316,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinking, content := thinkingState.AddContent(cr.Content)
|
||||
res.Thinking = thinking
|
||||
res.Response = content
|
||||
}
|
||||
|
||||
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
@@ -323,11 +349,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var r api.GenerateResponse
|
||||
var sb strings.Builder
|
||||
var sbThinking strings.Builder
|
||||
var sbContent strings.Builder
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.GenerateResponse:
|
||||
sb.WriteString(t.Response)
|
||||
sbThinking.WriteString(t.Thinking)
|
||||
sbContent.WriteString(t.Response)
|
||||
r = t
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
@@ -343,7 +371,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
r.Response = sb.String()
|
||||
r.Thinking = sbThinking.String()
|
||||
r.Response = sbContent.String()
|
||||
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
@@ -899,8 +929,7 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
r := api.ListModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
@@ -913,7 +942,16 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
ParameterSize: cf.ModelType,
|
||||
QuantizationLevel: cf.FileType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
model, err := GetModel(n.String())
|
||||
if err != nil {
|
||||
slog.Warn("bad model details", "name", n, "error", err)
|
||||
} else {
|
||||
r.Capabilities = model.Capabilities()
|
||||
}
|
||||
|
||||
models = append(models, r)
|
||||
}
|
||||
|
||||
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
|
||||
@@ -1435,6 +1473,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
if req.Think != nil && *req.Think {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
@@ -1475,18 +1516,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinking.Parser{
|
||||
OpeningTag: openingTag,
|
||||
ClosingTag: closingTag,
|
||||
}
|
||||
}
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
var sb strings.Builder
|
||||
var toolCallIndex int = 0
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
@@ -1506,43 +1565,40 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
|
||||
// however this was a simple change for now without reworking streaming logic of this (and other)
|
||||
// handlers
|
||||
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
|
||||
ch <- res
|
||||
return
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming tool calls:
|
||||
// If tools are recognized, use a flag to track the sending of a tool downstream
|
||||
// This ensures that content is cleared from the message on the last chunk sent
|
||||
sb.WriteString(r.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
toolCalls[i].Function.Index = toolCallIndex
|
||||
toolCallIndex++
|
||||
}
|
||||
res.Message.Content = ""
|
||||
sb.Reset()
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
// Send any remaining content if no tool calls were detected
|
||||
if toolCallIndex == 0 {
|
||||
res.Message.Content = sb.String()
|
||||
}
|
||||
ch <- res
|
||||
}
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
@@ -1550,12 +1606,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var resp api.ChatResponse
|
||||
var sb strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
var sbThinking strings.Builder
|
||||
var sbContent strings.Builder
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.ChatResponse:
|
||||
sb.WriteString(t.Message.Content)
|
||||
sbThinking.WriteString(t.Message.Thinking)
|
||||
sbContent.WriteString(t.Message.Content)
|
||||
resp = t
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls = append(toolCalls, t.Message.ToolCalls...)
|
||||
}
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
@@ -1570,13 +1632,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
resp.Message.Content = sb.String()
|
||||
resp.Message.Content = sbContent.String()
|
||||
resp.Message.Thinking = sbThinking.String()
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
@@ -1601,8 +1661,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
|
||||
|
||||
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
|
||||
finalUserIndex := -1
|
||||
@@ -1614,7 +1672,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
|
||||
for i, msg := range msgs {
|
||||
if msg.Role == "assistant" && i < finalUserIndex {
|
||||
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
|
||||
// TODO(drifkin): this is from before we added proper thinking support.
|
||||
// However, even if thinking is not enabled (and therefore we shouldn't
|
||||
// change the user output), we should probably perform this filtering
|
||||
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
|
||||
// to save tokens and improve quality.
|
||||
thinkingState := &thinking.Parser{
|
||||
OpeningTag: "<think>",
|
||||
ClosingTag: "</think>",
|
||||
}
|
||||
_, content := thinkingState.AddContent(msg.Content)
|
||||
msgs[i].Content = content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing thinking capability", func(t *testing.T) {
|
||||
think := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &think,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
|
||||
@@ -167,6 +167,10 @@ type Values struct {
|
||||
api.Tools
|
||||
Prompt string
|
||||
Suffix string
|
||||
Think bool
|
||||
// whether or not the user explicitly set the thinking flag (vs. it being
|
||||
// implicitly false). Templates can't see whether `Think` is nil
|
||||
IsThinkSet bool
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
@@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
system, messages := collate(v.Messages)
|
||||
if v.Prompt != "" && v.Suffix != "" {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Response": "",
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
for _, m := range messages {
|
||||
execute := func() error {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
|
||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
171
thinking/parser.go
Normal file
171
thinking/parser.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package thinking
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type thinkingState int
|
||||
|
||||
const (
|
||||
// We're looking for the opening tag, but we haven't seen any non-whitespace
|
||||
// characters yet
|
||||
thinkingState_LookingForOpening thinkingState = iota
|
||||
// We've seen the opening tag, but we haven't seen any non-whitespace
|
||||
// characters yet (we want to eat any whitespace between the opening tag and
|
||||
// the thinking content)
|
||||
thinkingState_ThinkingStartedEatingWhitespace
|
||||
// We've seen non-whitespace characters after the opening tag, but we haven't
|
||||
// seen the closing tag yet
|
||||
thinkingState_Thinking
|
||||
// We've seen the closing tag, but we haven't seen any non-whitespace
|
||||
// characters after the closing tag yet (we want to eat any whitespace between
|
||||
// the closing tag and the content)
|
||||
thinkingState_ThinkingDoneEatingWhitespace
|
||||
// We've seen the closing tag and seen at least one non-whitespace character
|
||||
// after it
|
||||
thinkingState_ThinkingDone
|
||||
)
|
||||
|
||||
func (s thinkingState) String() string {
|
||||
switch s {
|
||||
case thinkingState_LookingForOpening:
|
||||
return "LookingForOpening"
|
||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||
return "ThinkingStartedEatingWhitespace"
|
||||
case thinkingState_Thinking:
|
||||
return "Thinking"
|
||||
case thinkingState_ThinkingDoneEatingWhitespace:
|
||||
return "ThinkingDoneEatingWhitespace"
|
||||
case thinkingState_ThinkingDone:
|
||||
return "ThinkingDone"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
state thinkingState
|
||||
OpeningTag string
|
||||
ClosingTag string
|
||||
acc strings.Builder
|
||||
}
|
||||
|
||||
// AddContent returns the thinking content and the non-thinking content that
|
||||
// should be immediately sent to the user. It will internally buffer if it needs
|
||||
// to see more raw content to disambiguate
|
||||
func (s *Parser) AddContent(content string) (string, string) {
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var thinkingSb, remainingSb strings.Builder
|
||||
|
||||
var thinking, remaining string
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
thinking, remaining, keepLooping = eat(s)
|
||||
thinkingSb.WriteString(thinking)
|
||||
remainingSb.WriteString(remaining)
|
||||
}
|
||||
|
||||
return thinkingSb.String(), remainingSb.String()
|
||||
}
|
||||
|
||||
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *Parser) (string, string, bool) {
|
||||
switch s.state {
|
||||
case thinkingState_LookingForOpening:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, s.OpeningTag) {
|
||||
after := strings.Join(strings.Split(trimmed, s.OpeningTag)[1:], s.OpeningTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
// after might contain more than just thinking tokens, so we continue
|
||||
// parsing instead of returning it as thinking tokens here
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
if after == "" {
|
||||
s.state = thinkingState_ThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
s.state = thinkingState_Thinking
|
||||
}
|
||||
return "", "", true
|
||||
} else if strings.HasPrefix(s.OpeningTag, trimmed) {
|
||||
// partial opening seen, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else if trimmed == "" {
|
||||
// saw whitespace only, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else {
|
||||
// didn't see an opening tag, but we have content, so thinking was skipped
|
||||
s.state = thinkingState_ThinkingDone
|
||||
// note that we use the original content, not the trimmed one because we
|
||||
// don't want to eat any whitespace in the real content if there were no
|
||||
// thinking tags
|
||||
return "", s.acc.String(), false
|
||||
}
|
||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
if trimmed == "" {
|
||||
return "", "", false
|
||||
} else {
|
||||
s.state = thinkingState_Thinking
|
||||
s.acc.WriteString(trimmed)
|
||||
return "", "", true
|
||||
}
|
||||
case thinkingState_Thinking:
|
||||
acc := s.acc.String()
|
||||
if strings.Contains(acc, s.ClosingTag) {
|
||||
split := strings.Split(acc, s.ClosingTag)
|
||||
thinking := split[0]
|
||||
remaining := strings.Join(split[1:], s.ClosingTag)
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
if remaining == "" {
|
||||
s.state = thinkingState_ThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
s.state = thinkingState_ThinkingDone
|
||||
}
|
||||
return thinking, remaining, false
|
||||
} else if overlapLen := overlap(acc, s.ClosingTag); overlapLen > 0 {
|
||||
thinking := acc[:len(acc)-overlapLen]
|
||||
remaining := acc[len(acc)-overlapLen:]
|
||||
s.acc.Reset()
|
||||
// keep track of the candidate closing tag. We have to buffer it until it
|
||||
// becomes disambiguated
|
||||
s.acc.WriteString(remaining)
|
||||
return thinking, "", false
|
||||
} else {
|
||||
// purely just thinking tokens, so we can return them
|
||||
s.acc.Reset()
|
||||
return acc, "", false
|
||||
}
|
||||
case thinkingState_ThinkingDoneEatingWhitespace:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
// if we see non-whitespace, we're done eating the leading whitespace of the content
|
||||
if trimmed != "" {
|
||||
s.state = thinkingState_ThinkingDone
|
||||
}
|
||||
return "", trimmed, false
|
||||
case thinkingState_ThinkingDone:
|
||||
acc := s.acc.String()
|
||||
s.acc.Reset()
|
||||
return "", acc, false
|
||||
default:
|
||||
panic("unknown state")
|
||||
}
|
||||
}
|
||||
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
278
thinking/parser_test.go
Normal file
278
thinking/parser_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package thinking
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantContent, wantThink string
|
||||
}{
|
||||
{
|
||||
in: "<think> internal </think> world",
|
||||
wantThink: "internal ",
|
||||
wantContent: "world",
|
||||
},
|
||||
{
|
||||
in: "<think>a</think><think>b</think>c",
|
||||
wantThink: "a",
|
||||
wantContent: "<think>b</think>c",
|
||||
},
|
||||
{
|
||||
in: "no think",
|
||||
wantThink: "",
|
||||
wantContent: "no think",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := Parser{
|
||||
OpeningTag: "<think>",
|
||||
ClosingTag: "</think>",
|
||||
}
|
||||
gotThinking, gotContent := parser.AddContent(tt.in)
|
||||
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
|
||||
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantThinking string
|
||||
wantContent string
|
||||
wantStateAfter thinkingState
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
skip bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "content without a thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc",
|
||||
wantThinking: "",
|
||||
wantContent: " abc",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before a thinking tag nerfs the thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc <think>def</think> ghi",
|
||||
wantThinking: "",
|
||||
wantContent: " abc <think>def</think> ghi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "building up a thinking tag partially",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <th",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "in",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "k>a",
|
||||
wantThinking: "a",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ink>def",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ing>def",
|
||||
wantThinking: "</thing>def",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ghi</thi",
|
||||
wantThinking: "ghi",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "nk>jkl",
|
||||
wantThinking: "",
|
||||
wantContent: "jkl",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>\n\ndef",
|
||||
wantThinking: "abc",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag (incremental)",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n\ndef",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag with content and more whitespace",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>\n\ndef ",
|
||||
wantThinking: "abc",
|
||||
wantContent: "def ",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
{
|
||||
input: " ghi",
|
||||
wantThinking: "",
|
||||
wantContent: " ghi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "token by token",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "</think>",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n\n",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "Hi",
|
||||
wantThinking: "",
|
||||
wantContent: "Hi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantThinking: "",
|
||||
wantContent: " there",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "leading thinking whitespace",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think> \t ",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: " these are some ",
|
||||
wantThinking: "these are some ",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "thoughts </think> ",
|
||||
wantThinking: "thoughts ",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: " more content",
|
||||
wantThinking: "",
|
||||
wantContent: "more content",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
parser := Parser{
|
||||
OpeningTag: "<think>",
|
||||
ClosingTag: "</think>",
|
||||
}
|
||||
if c.skip {
|
||||
continue
|
||||
}
|
||||
for i, step := range c.steps {
|
||||
thinking, content := parser.AddContent(step.input)
|
||||
if content != step.wantContent || thinking != step.wantThinking {
|
||||
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
|
||||
}
|
||||
if parser.state != step.wantStateAfter {
|
||||
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state, step.wantStateAfter)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
134
thinking/template.go
Normal file
134
thinking/template.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package thinking
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
)
|
||||
|
||||
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
shouldContinue := enterFn(n)
|
||||
if !shouldContinue {
|
||||
return
|
||||
}
|
||||
switch x := n.(type) {
|
||||
case *parse.ListNode:
|
||||
for _, c := range x.Nodes {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.BranchNode:
|
||||
if x.Pipe != nil {
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
}
|
||||
if x.List != nil {
|
||||
templateVisit(x.List, enterFn, exitFn)
|
||||
}
|
||||
if x.ElseList != nil {
|
||||
templateVisit(x.ElseList, enterFn, exitFn)
|
||||
}
|
||||
case *parse.ActionNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.WithNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.RangeNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.IfNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.TemplateNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.PipeNode:
|
||||
for _, c := range x.Cmds {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.CommandNode:
|
||||
for _, a := range x.Args {
|
||||
templateVisit(a, enterFn, exitFn)
|
||||
}
|
||||
// text, field, number, etc. are leaves – nothing to recurse into
|
||||
}
|
||||
if exitFn != nil {
|
||||
exitFn(n)
|
||||
}
|
||||
}
|
||||
|
||||
// InferTags uses a heuristic to infer the tags that surround thinking traces:
|
||||
// We look for a range node that iterates over "Messages" and then look for a
|
||||
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
||||
// ListNode and take the first and last TextNodes as the opening and closing
|
||||
// tags.
|
||||
func InferTags(t *template.Template) (string, string) {
|
||||
ancestors := []parse.Node{}
|
||||
|
||||
openingTag := ""
|
||||
closingTag := ""
|
||||
|
||||
enterFn := func(n parse.Node) bool {
|
||||
ancestors = append(ancestors, n)
|
||||
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
||||
var mostRecentRange *parse.RangeNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
||||
mostRecentRange = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(drifkin): to be more robust, check that it's in the action
|
||||
// part, not the `if`'s pipeline part. We do match on the nearest list
|
||||
// that starts and ends with text nodes, which makes this not strictly
|
||||
// necessary for our heuristic
|
||||
|
||||
// go up to the nearest ancestor that is a *parse.ListNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
||||
firstNode := l.Nodes[0]
|
||||
if t, ok := firstNode.(*parse.TextNode); ok {
|
||||
openingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
lastNode := l.Nodes[len(l.Nodes)-1]
|
||||
if t, ok := lastNode.(*parse.TextNode); ok {
|
||||
closingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
exitFn := func(n parse.Node) {
|
||||
ancestors = ancestors[:len(ancestors)-1]
|
||||
}
|
||||
|
||||
templateVisit(t.Root, enterFn, exitFn)
|
||||
|
||||
return openingTag, closingTag
|
||||
}
|
||||
|
||||
// checks to see if the given field name is present in the pipeline of the given range node
|
||||
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
||||
found := false
|
||||
enterFn := func(n parse.Node) bool {
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if x.Ident[0] == field {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
||||
return found
|
||||
}
|
||||
130
thinking/template_test.go
Normal file
130
thinking/template_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package thinking
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestInferThinkingTags(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
tmplString string
|
||||
wantOpeningTag string
|
||||
wantClosingTag string
|
||||
}{
|
||||
{
|
||||
desc: "basic",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
{
|
||||
desc: "doubly nested range",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- range $j, $_ := .NotMessages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "",
|
||||
wantClosingTag: "",
|
||||
},
|
||||
{
|
||||
desc: "whitespace is trimmed",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
Some text before {{ .Thinking }} Some text after
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "Some text before",
|
||||
wantClosingTag: "Some text after",
|
||||
},
|
||||
{
|
||||
desc: "qwen3",
|
||||
tmplString: `
|
||||
{{- if or .System .Tools .Thinking }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- if .Thinking }}
|
||||
/think
|
||||
{{- else }}
|
||||
/no_think
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
tmpl := template.Must(template.New("test").Parse(c.tmplString))
|
||||
openingTag, closingTag := InferTags(tmpl)
|
||||
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
|
||||
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
tools/testdata/llama3.2.gotmpl
vendored
Normal file
44
tools/testdata/llama3.2.gotmpl
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
{{ if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
||||
{{- if and $.Tools $last }}
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{{ range $.Tools }}
|
||||
{{- . }}
|
||||
{{ end }}
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}
|
||||
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3.2.out
vendored
Normal file
24
tools/testdata/llama3.2.out
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
51
tools/testdata/qwen2.5.gotmpl
vendored
Normal file
51
tools/testdata/qwen2.5.gotmpl
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
|
||||
{{- else if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen2.5.out
vendored
Normal file
31
tools/testdata/qwen2.5.out
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
50
tools/testdata/qwen3.gotmpl
vendored
Normal file
50
tools/testdata/qwen3.gotmpl
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen3.out
vendored
Normal file
31
tools/testdata/qwen3.out
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
253
tools/tools.go
Normal file
253
tools/tools.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToolCall = errors.New("invalid tool call format")
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to parse
|
||||
// - name: The field name from template that identifies the tool call name
|
||||
// - arguments: The field name from template that identifies the tool call arguments
|
||||
//
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||
// Check for balanced braces before attempting to parse
|
||||
braceCount := 0
|
||||
squareCount := 0
|
||||
startIndex := -1
|
||||
var rawToolCalls []string
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
case '[':
|
||||
if trackSquareBrackets {
|
||||
squareCount++
|
||||
}
|
||||
case ']':
|
||||
if trackSquareBrackets {
|
||||
squareCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Negative means we have an extra closing brace/bracket
|
||||
if braceCount < 0 || squareCount < 0 {
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// If braces/brackets aren't balanced, need more input
|
||||
if braceCount > 0 || squareCount > 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
t := strings.TrimSpace(s)
|
||||
if len(t) == 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
// If the input is a single square bracket, it's not a valid tool call
|
||||
if t[0] == '[' && len(t) == 1 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var toolCalls []api.ToolCall
|
||||
for _, rawToolCall := range rawToolCalls {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect nested objects that could contain tool calls
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
if s == "" || p.prefix == "" {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
p.prefixFound = true
|
||||
return cut, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||
// Return everything before prefix
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParseJSON && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, s
|
||||
}
|
||||
|
||||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
}
|
||||
p.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.prefix != "" {
|
||||
p.greedyParseJSON = false
|
||||
}
|
||||
if p.index != 0 && p.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return toolCalls, ""
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParseJSON: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
}
|
||||
673
tools/tools_test.go
Normal file
673
tools/tools_test.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
func TestParseJSONToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
nameField string
|
||||
argsField string
|
||||
wantToolCalls []api.ToolCall
|
||||
wantErr error
|
||||
prefix string
|
||||
}{
|
||||
{
|
||||
name: "valid single tool call",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test_tool",
|
||||
Arguments: map[string]any{
|
||||
"arg1": "value1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "incomplete JSON",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: `not json at all`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errInvalidToolCall,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "missing required fields",
|
||||
input: `{"other": "field"}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errInvalidToolCall,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls in array",
|
||||
input: `[
|
||||
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
]`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls without array",
|
||||
input: `
|
||||
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls with text after",
|
||||
input: `
|
||||
{"name": "tool1", "arguments": {"arg1": 1}} text
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}} text
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "second tool call in array",
|
||||
input: `
|
||||
, {"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
// a bad JSON would not return any tool calls or content as it would always accumulate more
|
||||
{
|
||||
name: "unbalanced square brackets",
|
||||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "incomplete square brackets",
|
||||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "nested arrays in arguments",
|
||||
input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": []any{float64(1), float64(2), []any{"nested", "array"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix)
|
||||
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if len(gotCalls) != 0 && tt.wantErr != nil {
|
||||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCalls(t *testing.T) {
|
||||
p := filepath.Join("testdata")
|
||||
t1 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}
|
||||
t2 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
output string
|
||||
expectedToolCall []api.ToolCall
|
||||
expectedTokens string
|
||||
}{
|
||||
{
|
||||
name: "mistral malformed json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral multiple tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls with text between no prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral valid json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral multiple tool calls with text between and prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral incomplete json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral invalid tool call with explanatory text no prefix",
|
||||
model: "mistral",
|
||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "command r plus tool calls with json block format",
|
||||
model: "command-r-plus",
|
||||
output: "Action: ```json" + `
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada"
|
||||
}
|
||||
}
|
||||
]
|
||||
` + "```",
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "firefunction tool calls with functools prefix",
|
||||
model: "firefunction",
|
||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "llama3 groq single tool call with xml tags",
|
||||
model: "llama3-groq-tool-use",
|
||||
output: `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "xlam tool calls with wrapper object",
|
||||
model: "xlam",
|
||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 single tool call with prefix",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 multiple tool calls with and without prefix",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 plain text response no tool calls",
|
||||
model: "qwen2.5",
|
||||
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with trailing text",
|
||||
model: "qwen2.5",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens after call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with initial text",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and initial text",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens before call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without and with prefix",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without and with prefix and text between",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call> some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens between",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens",
|
||||
model: "qwen2.5",
|
||||
output: `hi [{"options": "foo"}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `hi [{"options": "foo"}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and invalid tool call",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: ``,
|
||||
},
|
||||
{
|
||||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 empty think prefix with tool prefix and valid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: `<think></think>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, no prefix in output",
|
||||
model: "qwen2.5",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, prefix in output",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output",
|
||||
model: "llama3.2",
|
||||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, single tool call",
|
||||
model: "llama3.2",
|
||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, multiple tool calls in list",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, individual tool calls",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call> {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, no prefix in output, tokens before",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, prefix in output, tokens after",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, tokens after",
|
||||
model: "llama3.2",
|
||||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, tokens before",
|
||||
model: "llama3.2",
|
||||
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `some tokens before`,
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, tokens after",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call>
|
||||
[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model without without prefix, match all jsons",
|
||||
model: "llama3.2",
|
||||
output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "model outputs some text",
|
||||
},
|
||||
{
|
||||
name: "model flushes tokens if tool call doesn't match",
|
||||
model: "llama3.2",
|
||||
output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||
},
|
||||
{
|
||||
name: "model flushes tokens if tool call doesn't match array",
|
||||
model: "llama3.2",
|
||||
output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||
},
|
||||
}
|
||||
|
||||
var tools []api.Tool
|
||||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("template", func(t *testing.T) {
|
||||
actual := &bytes.Buffer{} // Create new buffer for each test
|
||||
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse", func(t *testing.T) {
|
||||
tp, err := NewParser(tmpl.Template)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := []api.ToolCall{}
|
||||
var gotTokens strings.Builder
|
||||
|
||||
tokens := strings.Fields(tt.output)
|
||||
for _, tok := range tokens {
|
||||
s := " " + tok
|
||||
|
||||
toolCalls, content := tp.Add(s)
|
||||
if len(content) > 0 {
|
||||
gotTokens.WriteString(content)
|
||||
} else if len(toolCalls) > 0 {
|
||||
got = append(got, toolCalls...)
|
||||
}
|
||||
}
|
||||
|
||||
// Compare tool calls if we expect any
|
||||
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
// Compare tokens if we expect any
|
||||
stripped := strings.TrimSpace(gotTokens.String())
|
||||
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
||||
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
||||
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
222
tools/tools_utils.go
Normal file
222
tools/tools_utils.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
||||
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
||||
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The extracted text following the first ".ToolCalls" condition found
|
||||
// - bool: Whether a ".ToolCalls" condition was found in the template
|
||||
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("template or tree is nil")
|
||||
return "", false
|
||||
}
|
||||
|
||||
var result string
|
||||
var found bool
|
||||
|
||||
var walk func(nodes []parse.Node)
|
||||
walk = func(nodes []parse.Node) {
|
||||
for _, node := range nodes {
|
||||
if found {
|
||||
return
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if isToolCallsNode(n) {
|
||||
// Collect immediate TextNode(s) at start of IfNode's list
|
||||
var sb strings.Builder
|
||||
for _, innerNode := range n.List.Nodes {
|
||||
if tn, ok := innerNode.(*parse.TextNode); ok {
|
||||
sb.Write(tn.Text)
|
||||
} else {
|
||||
// Stop at first non-text node
|
||||
break
|
||||
}
|
||||
}
|
||||
result = sb.String()
|
||||
found = true
|
||||
return
|
||||
}
|
||||
// Recurse into child nodes
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.ListNode:
|
||||
walk(n.Nodes)
|
||||
case *parse.RangeNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.WithNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
default:
|
||||
// Continue to next node
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
walk(tmpl.Tree.Root.Nodes)
|
||||
return result, found
|
||||
}
|
||||
|
||||
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
||||
func isToolCallsNode(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
if slices.Contains(field.Ident, "ToolCalls") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
tokenText = strings.TrimSpace(tokenText)
|
||||
tokenText = strings.ReplaceAll(tokenText, "\r", "")
|
||||
tokenText = strings.ReplaceAll(tokenText, "\n", " ")
|
||||
|
||||
return tokenText
|
||||
}
|
||||
|
||||
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
||||
//
|
||||
// Returns:
|
||||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||
// - error: Error if parsing failed
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
|
||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
}
|
||||
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
|
||||
//
|
||||
// Returns:
|
||||
// - int: The starting index in s where the suffix overlap begins
|
||||
func suffixOverlap(s, prefix string) int {
|
||||
max := min(len(prefix), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, prefix[:i]) {
|
||||
return len(s) - i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
||||
//
|
||||
// Returns:
|
||||
// - string: The name of the tool call
|
||||
// - string: The arguments of the tool call
|
||||
// - error: Error if parsing failed
|
||||
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Extract JSON object between curly braces
|
||||
// JSON arrays are also valid as they will not be repeated in the template
|
||||
output := b.String()
|
||||
start := strings.Index(output, "{")
|
||||
end := strings.LastIndex(output, "}")
|
||||
if start == -1 || end == -1 || start > end {
|
||||
return "", "", errors.New("no valid JSON object found in template output")
|
||||
}
|
||||
jsonStr := output[start : end+1]
|
||||
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Find name and arguments fields
|
||||
for k, v := range obj {
|
||||
if str, ok := v.(string); ok && str == "@@name@@" {
|
||||
name = k
|
||||
} else if _, ok := v.(map[string]any); ok {
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
|
||||
return "", "", errors.New("missing required fields in tool call template")
|
||||
}
|
||||
|
||||
return name, arguments, nil
|
||||
}
|
||||
|
||||
// collect recursively traverses an object to collect all nested maps
|
||||
//
|
||||
// Returns:
|
||||
// - []map[string]any: A slice of all nested maps found in the object
|
||||
func collect(obj any) []map[string]any {
|
||||
var all []map[string]any
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
497
tools/tools_utils_test.go
Normal file
497
tools/tools_utils_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func TestExtractToolCallsFormat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil template",
|
||||
template: "",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "basic tool call with text",
|
||||
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||
want: "Hello world",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json\n",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call in range",
|
||||
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "nested if without tool calls",
|
||||
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, found := extractToolCallsFormat(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
if found != tc.found {
|
||||
t.Errorf("got found %v, want %v", found, tc.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPrefix(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with action prefix",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action: ```json",
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools[",
|
||||
},
|
||||
{
|
||||
name: "tool call with angle brackets",
|
||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||
want: "Hello, world! <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "multiple tool call formats",
|
||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||
want: "[tool_call] <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "single angle bracket tool call",
|
||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "incomplete angle bracket after tool call",
|
||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||
want: "[tool_call] <",
|
||||
},
|
||||
{
|
||||
name: "angle bracket prefix with tool call",
|
||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||
want: "> <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL] [",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL][",
|
||||
},
|
||||
{
|
||||
name: "tool call with pipe delimiters",
|
||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||
want: "<|tool_call|>",
|
||||
},
|
||||
{
|
||||
name: "tool with no prefix",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
got := toolPrefix(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolTemplate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call range",
|
||||
template: "{{range .ToolCalls}}test{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: "{{range .Other}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nested tool calls",
|
||||
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "tool calls in if statement",
|
||||
template: "{{if .ToolCalls}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := template.Parse(tmpl.Root.String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
_, err = toolTemplate(parsed)
|
||||
if err != nil && tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuffixOverlap(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
s string
|
||||
d string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "no overlap",
|
||||
s: "hello world",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "full overlap",
|
||||
s: "<tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "partial overlap",
|
||||
s: "text <tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 5,
|
||||
},
|
||||
{
|
||||
name: "delimiter longer than string",
|
||||
s: "<tool>",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
s: "",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty delimiter",
|
||||
s: "<tool_call>",
|
||||
d: "",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "single char overlap",
|
||||
s: "test<",
|
||||
d: "<tool_call>",
|
||||
want: 4,
|
||||
},
|
||||
{
|
||||
name: "partial tool call",
|
||||
s: "hello <tool_",
|
||||
d: "<tool_call>",
|
||||
want: 6,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := suffixOverlap(tt.s, tt.d)
|
||||
if got != tt.want {
|
||||
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
wantName string
|
||||
wantArgs string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with whitespace",
|
||||
template: `{{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}}
|
||||
{{end}}`,
|
||||
wantName: "name",
|
||||
wantArgs: "parameters",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with extra content",
|
||||
template: `Before {{range .ToolCalls}}
|
||||
{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: `{{if .Something}}no tools here{{end}}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: ``,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "prefix within tool call",
|
||||
template: `{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
<tool_call>
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
</tool_call>{{ end }}{{- end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JSON array",
|
||||
template: `{{ range .ToolCalls }}
|
||||
[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`,
|
||||
wantName: "name",
|
||||
wantArgs: "arguments",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing name field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"parameters": {{ .Function.Arguments }}}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing arguments field",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}"}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
template: `{{ range .ToolCalls }}
|
||||
{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`,
|
||||
wantName: "",
|
||||
wantArgs: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
gotName, gotArgs, err := extractToolArgs(tmpl)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotArgs != tt.wantArgs {
|
||||
t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollect(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
obj any
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
name: "simple map",
|
||||
obj: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested map",
|
||||
obj: map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": "value",
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"outer": map[string]any{"inner": "value"}},
|
||||
{"inner": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array of maps",
|
||||
obj: []any{
|
||||
map[string]any{"key1": "val1"},
|
||||
map[string]any{"key2": "val2"},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key1": "val1"},
|
||||
{"key2": "val2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
obj: map[string]any{
|
||||
"l1": map[string]any{
|
||||
"l2": map[string]any{
|
||||
"l3": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||
{"l2": map[string]any{"l3": "value"}},
|
||||
{"l3": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-map value",
|
||||
obj: "string",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := collect(tt.obj)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare each map in the result
|
||||
for i := range tt.want {
|
||||
if !mapsEqual(got[i], tt.want[i]) {
|
||||
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mapsEqual compares two maps for deep equality
|
||||
func mapsEqual(m1, m2 map[string]any) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v1 := range m1 {
|
||||
v2, ok := m2[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch val1 := v1.(type) {
|
||||
case map[string]any:
|
||||
val2, ok := v2.(map[string]any)
|
||||
if !ok || !mapsEqual(val1, val2) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -8,6 +8,7 @@ const (
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
Reference in New Issue
Block a user