Compare commits

...

44 Commits

Author SHA1 Message Date
Bruce MacDonald
f2a4d058f9 gofmt 2025-06-16 16:34:46 -07:00
Bruce MacDonald
63e7634014 pr feedback 2025-06-16 16:08:38 -07:00
Bruce MacDonald
8d51d92f3b server: cache gguf model capabilities rather than reading off disc 2025-06-16 15:17:36 -07:00
Bruce MacDonald
2348fef568 Revert "server: model info caching system for improved performance"
This reverts commit 8ef643d4978168a8563ae24434a424358ce390e3.
2025-06-16 15:17:02 -07:00
Bruce MacDonald
883f655dd6 server: model info caching system for improved performance
Implements an in-memory cache for loaded models with file modification
time tracking to ensure cache validity. Models are now cached after
first load and retrieved from cache on subsequent requests if the
underlying manifest file hasn't changed.

Key changes:
- Add ModelCache with get/set methods and modification time validation
- Cache models in GetModel() and check cache before disk load
- Move capabilities calculation to model loading time and store in model
- Update capability access to use cached field instead of runtime calculation
- Add test coverage for cache behavior and model loading

This reduces redundant model loading operations and improves response
times for model access.
2025-06-16 15:16:58 -07:00
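A minimal sketch of the mtime-validated cache pattern described in the (since-reverted and reworked) commit above. The `ModelCache` name and get/set shape come from the commit message; the `Model` type, manifest-path keying, and field layout here are assumptions, not the actual server code:

```go
package server

import (
	"os"
	"sync"
	"time"
)

// Model is a stand-in for the server's loaded-model type (assumption).
type Model struct {
	Name         string
	Capabilities []string
}

type cacheEntry struct {
	model   *Model
	modTime time.Time // manifest mtime when the model was cached
}

// ModelCache keeps loaded models keyed by manifest path and treats an
// entry as stale once the manifest's modification time changes.
type ModelCache struct {
	mu      sync.Mutex
	entries map[string]cacheEntry
}

func NewModelCache() *ModelCache {
	return &ModelCache{entries: make(map[string]cacheEntry)}
}

// Get returns the cached model if the manifest hasn't changed since Set.
func (c *ModelCache) Get(manifestPath string) (*Model, bool) {
	fi, err := os.Stat(manifestPath)
	if err != nil {
		return nil, false
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	e, ok := c.entries[manifestPath]
	if !ok || !e.modTime.Equal(fi.ModTime()) {
		return nil, false // miss or stale
	}
	return e.model, true
}

// Set stores the model along with the manifest's current mtime.
func (c *ModelCache) Set(manifestPath string, m *Model) {
	fi, err := os.Stat(manifestPath)
	if err != nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.entries[manifestPath] = cacheEntry{model: m, modTime: fi.ModTime()}
}
```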
Michael Yang
a6fbfc880c gguf: fix write order (#11068)
* ggml: test write gguf order
* ggml: fix write tensor order
2025-06-16 10:42:32 -07:00
NGC13009
502028968d readme: add ollama-launcher to community integrations (#11080) 2025-06-15 21:27:49 -07:00
Phil
5a8eb0e151 readme: add GPTranslate to community integrations (#11071) 2025-06-14 08:54:03 -07:00
Jeffrey Morgan
9f8a18ec05 tools: loosen tool parsing to allow for more formats (#11030) 2025-06-12 14:18:54 -07:00
Michael Yang
6b04cad7e8 feat: incremental gguf parser (#10822)
* incremental gguf parser
* gguf: update test to not rely on gguf on disc
* re-use existing create gguf
* read capabilities from gguf kv
* kv exists
* update tests
* s/doneFunc/successFunc/g
* new buffered reader

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2025-06-12 11:04:11 -07:00
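The new `fs/gguf` package (added below as `fs/gguf/gguf.go`) parses key-values and tensor infos lazily instead of decoding the whole file up front. A sketch of caller-side usage, based on the API exercised in `gguf_test.go`; the file path is a placeholder:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/fs/gguf"
)

func main() {
	f, err := gguf.Open("model.gguf") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Key-values are parsed incrementally; lookup stops as soon as the key is found.
	fmt.Println("architecture:", f.KeyValue("general.architecture").String())

	// Iterating tensor infos fast-forwards through any remaining key-values first.
	for _, ti := range f.TensorInfos() {
		fmt.Println(ti.Name, ti.Shape)
	}
}
```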
Michael Yang
45f56355d5 feat: uneven splits (#11048)
The current splitDim function only operates on tensors that are split evenly, which isn't always the case, e.g. a QKV tensor. This change allows the function to be used for arbitrary splits.
2025-06-11 12:10:54 -07:00
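A toy illustration of the offset arithmetic this enables, mirroring the `cmp.Or`/running-offset logic in the `convert/tensor.go` diff below; `splitRanges` is illustrative only, not part of the convert package:

```go
package main

import (
	"cmp"
	"fmt"
)

// splitRanges mimics how splitDim picks each slice: an explicit size if
// given, otherwise an even share, with a running offset.
func splitRanges(total int, sizes []int) [][2]int {
	var out [][2]int
	offset := 0
	for _, s := range sizes {
		n := cmp.Or(s, total/len(sizes)) // 0 means "take an even share"
		out = append(out, [2]int{offset, offset + n})
		offset += n
	}
	return out
}

func main() {
	// A fused tensor of size 12 split unevenly, e.g. q=8, k=2, v=2.
	fmt.Println(splitRanges(12, []int{8, 2, 2})) // [[0 8] [8 10] [10 12]]
	// Even split when no explicit sizes are given.
	fmt.Println(splitRanges(12, []int{0, 0, 0})) // [[0 4] [4 8] [8 12]]
}
```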
Michael Yang
0dabb4ef6a skip tokenizer.model if possible (#11050)
if tokenizer.json is already copied, skip tokenizer.model
2025-06-11 12:10:35 -07:00
Michael Yang
2e77aa1ae7 use nn.Linear in place of ml.Tensor (#11049)
while nn.Linear.Forward isn't applicable for sparse MLP, it's still
a nice container for the tensors
2025-06-11 12:10:15 -07:00
Attogram Project
deaabe292d readme: add ollama-multirun to community integrations (#11038) 2025-06-10 14:14:51 -07:00
Jeffrey Morgan
af21a5ac39 readme: update quickstart link text to Gemma 3 2025-06-10 09:34:23 -07:00
Jeffrey Morgan
f63d7f68eb readme: update quickstart example to Gemma 3 2025-06-10 09:33:54 -07:00
Daniel Hiltgen
82ad1dbc07 mac: handle "keep" named apps (#11031)
When a user elects to keep the existing app, the
new Ollama is named `Ollama 2.app`
This fixes the app startup flow to handle this naming pattern.
2025-06-09 16:29:57 -07:00
Daniel Hiltgen
feeabdadd2 spawn desktop quickly (#11011)
Give the desktop app a hint to start fast.
2025-06-08 09:34:52 -07:00
Krzysztof Jeziorny
fc0309615e docs: update link to AMD drivers in linux.md (#10973) 2025-06-06 23:30:04 -04:00
Jeffrey Morgan
09d308d6b6 Revert "server: add model capabilities to the list endpoint (#10174)" (#11004)
This reverts commit 0943001193.
2025-06-06 23:29:14 -04:00
Daniel Hiltgen
a8ed68bd93 launch app hidden (#10962)
When starting the app in the background, start it hidden.
2025-06-06 14:06:29 -07:00
Daniel Hiltgen
2ae65ae471 win: handle more than 2048 processes (#10997)
Fix an array out of bounds crash
2025-06-06 14:06:09 -07:00
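The fix, visible in the `cmd/start_windows.go` hunk further down, grows the PID buffer and retries when the enumeration reports more processes than fit, instead of indexing out of bounds. A toy sketch of that grow-and-retry pattern — `enumProcesses` here is a hypothetical stand-in, not `windows.EnumProcesses` itself:

```go
package main

import "fmt"

// enumProcesses fills buf with PIDs and reports how many exist in total
// (hypothetical stand-in for the real Windows call).
func enumProcesses(buf []uint32) (total int) {
	all := make([]uint32, 3000) // pretend the system has 3000 processes
	for i := range all {
		all[i] = uint32(i + 1)
	}
	copy(buf, all)
	return len(all)
}

func main() {
	pids := make([]uint32, 2048)
	ret := enumProcesses(pids)
	if ret > len(pids) {
		// Buffer was too small: grow past the reported count and retry.
		pids = make([]uint32, ret+10)
		ret = enumProcesses(pids)
	}
	if ret < len(pids) {
		pids = pids[:ret] // trim to the number actually returned
	}
	fmt.Println("got", len(pids), "pids")
}
```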
Devon Rifkin
a3b6886b7d move thinking logic into its own package (#10990)
move thinking logic into its own package
2025-06-06 12:02:20 -07:00
Hunter Wittenborn
c6a6d7294d docs: fix typo in development.md (#10998) 2025-06-06 12:07:29 -04:00
Devon Rifkin
2cf007c9d1 Merge pull request #10987 from ollama/drifkin/export-thinking-parser
export ThinkingParser
2025-06-05 12:19:14 -07:00
Devon Rifkin
0683efa637 export ThinkingParser 2025-06-05 10:22:32 -07:00
JasonHonKL
0943001193 server: add model capabilities to the list endpoint (#10174) 2025-06-04 11:39:48 -07:00
HardCodeDev
5c42800fca readme: add SimpleOllamaUnity to community integrations (#10817) 2025-05-30 19:50:16 -07:00
Parth Sareen
65f10c2823 tools: resiliency upgrade to name and arg extraction from template (#10917) 2025-05-30 15:18:09 -07:00
Jesse Gross
aaa7818000 ggml: Export GPU UUIDs
This enables matching up devices and information reported by the backend
with system management libraries such as nvml to get accurate free
memory reporting.
2025-05-29 14:01:26 -07:00
Jesse Gross
f15ffc4320 llm: Make "POST predict" error message more informative
"POST predict" basically means that the runner has crashed, which
can have many reasons. However, many people think this is a specific
error and either report only this message or group together unrelated
bugs. This replaces it with a more friendly and helpful message.
2025-05-29 09:41:19 -07:00
Devon Rifkin
5f57b0ef42 add thinking support to the api and cli (#10584)
- Both `/api/generate` and `/api/chat` now accept a `"think"`
  option that allows specifying whether thinking mode should be on or
  not
- Templates get passed this new option so, e.g., qwen3's template can
  put `/think` or `/no_think` in the system prompt depending on the
  value of the setting
- Models' thinking support is inferred by inspecting model templates.
  The prefix and suffix the parser uses to identify thinking support is
  also automatically inferred from templates
- Thinking control & parsing is opt-in via the API to prevent breaking
  existing API consumers. If the `"think"` option is not specified, the
  behavior is unchanged from previous versions of ollama
- Add parsing for thinking blocks in both streaming/non-streaming mode
  in both `/generate` and `/chat`
- Update the CLI to make use of these changes. Users can pass `--think`
  or `--think=false` to control thinking, or during an interactive
  session they can use the commands `/set think` or `/set nothink`
- A `--hidethinking` option has also been added to the CLI. This makes
  it easy to use thinking in scripting scenarios like
  `ollama run qwen3 --think --hidethinking "my question here"` where you
  just want to see the answer but still want the benefits of thinking
  models
2025-05-28 19:38:52 -07:00
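A sketch of opting in from the Go client, using the `Think *bool` field this change adds to `api.GenerateRequest` (the model name is illustrative):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	think := true // *bool: nil keeps the old behavior, false disables thinking
	req := &api.GenerateRequest{
		Model:  "qwen3", // illustrative; any thinking-capable model
		Prompt: "Why is the sky blue?",
		Think:  &think,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		if resp.Thinking != "" {
			fmt.Print(resp.Thinking) // text from inside the thinking tags
		}
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```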
Patrick Devine
aa25aff10d client: add request signing to the client (#10881)
If OLLAMA_AUTH is set, sign each request w/ a timestamp and pass the signature in the token header
2025-05-27 16:50:57 -07:00
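Per the `api/client.go` diff further down, the challenge is the method and path plus a Unix timestamp, signed with `auth.Sign` and sent in the `Authorization` header. A condensed sketch of just that step; `signRequest` is a hypothetical helper, not part of the client:

```go
package main

import (
	"context"
	"fmt"
	"strconv"
	"time"

	"github.com/ollama/ollama/auth"
)

// signRequest reproduces the challenge format from api/client.go:
// "<METHOD>,<PATH>?ts=<unix>", signed with the local Ollama key.
func signRequest(ctx context.Context, method, path string) (token, ts string, err error) {
	ts = strconv.FormatInt(time.Now().Unix(), 10)
	challenge := fmt.Sprintf("%s,%s?ts=%s", method, path, ts)
	token, err = auth.Sign(ctx, []byte(challenge))
	return token, ts, err
}

func main() {
	token, ts, err := signRequest(context.Background(), "POST", "/api/chat")
	if err != nil {
		panic(err)
	}
	// The timestamp also goes into the URL query ("?ts=...") so the
	// server can reconstruct the same challenge string.
	fmt.Println("ts:", ts)
	fmt.Println("Authorization:", token)
}
```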
Jesse Gross
ea79003180 kvcache: Skip computing causal mask for worst case graph reservation
Computing an attention mask for a large context and max batch is
expensive - over 100ms. Models like Gemma3 that have multiple types
of caches and custom attention masks need to do this 4 times, so this
adds approximately 500ms to startup time when using 128k context.

When we are reserving the worst case graph, we don't need the mask,
only its shape, so we can skip this.
2025-05-27 14:25:15 -07:00
Kyle Steere
9239a254e0 server: abort download on empty digest
Signed-off-by: Kyle Steere <kyle.steere@chainguard.dev>
2025-05-27 11:28:48 -07:00
Parth Sareen
066d0f4746 tools: relax JSON parse constraints for tool calling (#10872) 2025-05-26 18:59:06 -07:00
Parth Sareen
aea6fb9b58 tools: remove newline stripping (#10869) 2025-05-26 17:16:00 -07:00
RAPID ARCHITECT
012cf65340 readme: add AWS Strands Agents SDK example to community integrations (#10865) 2025-05-26 12:05:03 -07:00
Min Yoo
a45231af47 readme: Add macLlama to community integrations (#10790)
This commit updates the README to include macLlama within the community integrations section.

macLlama is a native macOS application built for lightweight and efficient LLM interaction.  Key features include:

*   **Lightweight & Native:** Designed to be resource-friendly and perform optimally on macOS.
*   **Chat-like Interface:** Provides a user-friendly, conversational interface.
*   **Multiple Window Support:** Allows users to manage multiple conversations simultaneously.

The primary goal of macLlama is to offer a simple and easy-to-run LLM experience on macOS.
2025-05-24 13:18:32 -07:00
Daniel Hiltgen
2307fc2bcd tests: drop llama3.2-vision embedding tests (#10837) 2025-05-24 13:17:53 -07:00
frob
6623898198 docs: remove unsupported quantizations (#10842) 2025-05-24 13:17:26 -07:00
frob
eda472df1b server: add hint to the error message when model path access fails (#10843) 2025-05-24 13:17:04 -07:00
Jesse Gross
f18e0cb550 ml: Improve slog formatting for BackendMemory 2025-05-23 20:08:23 -07:00
Parth Sareen
e8b981fa5d tools: refactor tool call parsing and enable streaming (#10415) 2025-05-23 14:19:31 -07:00
80 changed files with 5007 additions and 1267 deletions

View File

@@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
```shell
ollama run llama3.2
ollama run gemma3
```
## Model library
@@ -406,6 +406,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
### Cloud
@@ -449,6 +452,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
### Apple Vision Pro
@@ -585,6 +590,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
### Supported backends

View File

@@ -24,7 +24,10 @@ import (
"net/http"
"net/url"
"runtime"
"strconv"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@@ -76,6 +79,14 @@ func NewClient(base *url.URL, http *http.Client) *Client {
}
}
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
token, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return "", err
}
return token, nil
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte
@@ -97,6 +108,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
}
requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil {
return err
@@ -106,6 +132,10 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
respObj, err := c.http.Do(request)
if err != nil {
return err
@@ -143,6 +173,22 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
var err error
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
@@ -152,6 +198,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/x-ndjson")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
response, err := c.http.Do(request)
if err != nil {
return err

View File

@@ -83,6 +83,12 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -108,6 +114,10 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
}
type Tools []Tool
@@ -126,8 +136,11 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
@@ -478,6 +491,10 @@ type GenerateResponse struct {
// Response is the textual response itself.
Response string `json:"response"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
// Done specifies if the response is complete.
Done bool `json:"done"`

View File

@@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
})
}
}
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@@ -39,6 +39,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
@@ -46,6 +47,23 @@ import (
"github.com/ollama/ollama/version"
)
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -265,6 +283,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -299,6 +320,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
hidethinking, err := cmd.Flags().GetBool("hidethinking")
if err != nil {
return err
}
opts.HideThinking = hidethinking
keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil {
return err
@@ -362,6 +399,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
@@ -923,17 +965,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *bool
HideThinking bool
}
type displayResponseState struct {
@@ -989,6 +1033,26 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
}
}
func thinkingOutputOpeningText(plainText bool) string {
text := "Thinking...\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
}
func thinkingOutputClosingText(plainText bool) string {
text := "...done thinking.\n\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -1016,14 +1080,34 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error {
p.StopAndClear()
if response.Message.Content != "" || !opts.HideThinking {
p.StopAndClear()
}
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(false))
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(false))
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
@@ -1040,6 +1124,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
if opts.KeepAlive != nil {
@@ -1101,13 +1186,32 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}()
var state *displayResponseState = &displayResponseState{}
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
}
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagClosed = true
}
displayResponse(content, opts.WordWrap, state)
return nil
@@ -1133,6 +1237,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1348,6 +1453,8 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -1399,7 +1506,6 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler,
}
copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION",
Short: "Copy a model",
@@ -1488,3 +1594,45 @@ func NewCLI() *cobra.Command {
return rootCmd
}
// If the user has explicitly set thinking options, either through the CLI or
// through the `/set think` or `set nothink` interactive options, then we
// respect them. Otherwise, we check model capabilities to see if the model
// supports thinking. If the model does support thinking, we enable it.
// Otherwise, we unset the thinking option (which is different than setting it
// to false).
//
// If capabilities are not provided, we fetch them from the server.
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
if explicitlySetByUser {
return runOpts.Think, nil
}
if caps == nil {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
ret, err := client.Show(context.Background(), &api.ShowRequest{
Model: runOpts.Model,
})
if err != nil {
return nil, err
}
caps = &ret.Capabilities
}
thinkingSupported := false
for _, cap := range *caps {
if cap == model.CapabilityThinking {
thinkingSupported = true
}
}
if thinkingSupported {
thinking := true
return &thinking, nil
}
return nil, nil
}

View File

@@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "")
}
@@ -128,6 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
for {
line, err := scanner.Readline()
@@ -195,11 +198,19 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
continue
@@ -260,6 +271,22 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
@@ -448,6 +475,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
sb.Reset()
continue
}
return err
}
if assistant != nil {

View File

@@ -5,7 +5,7 @@ import (
"errors"
"os"
"os/exec"
"strings"
"regexp"
"github.com/ollama/ollama/api"
)
@@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error {
if err != nil {
return err
}
if !strings.Contains(link, "Ollama.app") {
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app")
}
path := strings.Split(link, "Ollama.app")
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err
}
return waitForServer(ctx, client)
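A quick check of what the new pattern from the hunk above matches — the default bundle name as well as the "keep"-suffixed copy; the paths are made up:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
	for _, link := range []string{
		"/Applications/Ollama.app/Contents/MacOS/ollama",
		"/Applications/Ollama 2.app/Contents/MacOS/ollama", // user chose "keep"
		"/usr/local/bin/ollama",                            // no app bundle: no match
	} {
		fmt.Printf("%-55q %v\n", link, r.FindStringSubmatch(link))
	}
}
```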

View File

@@ -45,14 +45,11 @@ func startApp(ctx context.Context, client *api.Client) error {
}
}
}
// log.Printf("XXX attempting to start app %s", appExe)
cmd_path := "c:\\Windows\\system32\\cmd.exe"
cmd := exec.Command(cmd_path, "/c", appExe)
// TODO - these hide flags aren't working - still pops up a command window for some reason
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
// TODO this didn't help either...
cmd.Stdin = strings.NewReader("")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -74,7 +71,16 @@ func isProcRunning(procName string) []uint32 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
pids = pids[:ret]
if ret > uint32(len(pids)) {
pids = make([]uint32, ret+10)
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
}
if ret < uint32(len(pids)) {
pids = pids[:ret]
}
var matches []uint32
for _, pid := range pids {
if pid == 0 {

cmd/warn_thinking_test.go (new file, 63 lines)
View File

@@ -0,0 +1,63 @@
package cmd
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
ensureThinkingSupport(t.Context(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
strings.NewReplacer("attn.qkv", "attn_q"),
strings.NewReplacer("attn.qkv", "attn_k"),
strings.NewReplacer("attn.qkv", "attn_v"),
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
))...)
} else {
out = append(out, &ggml.Tensor{

View File

@@ -1,53 +1,73 @@
package convert
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type split struct {
*strings.Replacer
dim int
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
for i, replacer := range replacers {
var offset int
for _, split := range splits {
t := t.Clone()
shape := slices.Clone(t.Shape())
shape[dim] = shape[dim] / uint64(len(replacers))
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
tt := t.Clone()
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(slice...)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
tt = tensor.Materialize(tt)
if split.fn != nil {
tt, err = split.fn(tt)
if err != nil {
return nil, err
}
}
// flatten tensor so it can be written as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
return native.VectorF32(tt.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: replacer.Replace(t.Name()),
Name: split.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: tt,
WriterTo: t,
}) {
break
}

convert/tensor_test.go (new file, 304 lines)
View File

@@ -0,0 +1,304 @@
package convert
import (
"bytes"
"encoding/binary"
"io"
"iter"
"slices"
"strings"
"testing"
"github.com/pdevine/tensor"
)
type fakeTensor struct {
name string
shape []uint64
data []float32
repacker Repacker
}
func (f fakeTensor) Name() string {
return f.name
}
func (f fakeTensor) Shape() []uint64 {
return f.shape
}
func (f fakeTensor) Kind() uint32 {
return 0
}
func (f *fakeTensor) SetRepacker(fn Repacker) {
f.repacker = fn
}
func (f fakeTensor) Clone() Tensor {
return &fakeTensor{
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
repacker: f.repacker,
}
}
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
data := f.data
if f.repacker != nil {
data, err = f.repacker(f.name, data, f.shape)
if err != nil {
return 0, err
}
}
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
return 0, err
}
return int64(len(data) * 4), nil
}
func mul(shape []uint64) int {
n := 1
for _, dim := range shape {
n *= int(dim)
}
return n
}
func TestSplitDim(t *testing.T) {
r := fakeTensor{
name: "a.b",
shape: []uint64{3, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
}
t.Run("no split", func(t *testing.T) {
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
if tt.Name != "x.b" {
t.Fatalf("expected name 'x', got '%s'", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
}
}
})
t.Run("even split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y")},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
}
}
})
t.Run("uneven split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
t.Fatal("expected shape [2, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
t.Fatal("expected shape [1, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
}
}
})
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
}
}
})
}

View File

@@ -43,6 +43,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
- `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
- `think`: (for thinking models) should the model think before responding?
Advanced parameters (optional):
@@ -490,11 +491,13 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported
- `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use

View File

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./...
```
> NOTE: In rare cirumstances, you may nedd to change a package using the new
> NOTE: In rare cirumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or

View File

@@ -132,22 +132,12 @@ success
### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`
#### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com

View File

@@ -112,8 +112,8 @@ sudo systemctl status ollama
> While AMD has contributed the `amdgpu` driver upstream to the official linux
> kernel source, the version is older and may not support all ROCm features. We
> recommend you install the latest driver from
> https://www.amd.com/en/support/linux-drivers for best support of your Radeon
> GPU.
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support
> of your Radeon GPU.
## Customizing

View File

@@ -183,6 +183,8 @@ var (
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
)
func String(s string) func() string {

View File

@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err
}
keys := slices.Collect(maps.Keys(kv))
slices.Sort(keys)
for _, key := range keys {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err
}
}
slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
if i, j := a.block(), b.block(); i > 0 && j > 0 {
return cmp.Compare(i, j)
}
return cmp.Compare(a.Name, b.Name)
})
var s uint64

View File

@@ -2,62 +2,82 @@ package ggml
import (
"bytes"
"math/rand/v2"
"os"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
r := rand.New(rand.NewPCG(0, 0))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, []*Tensor{
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err)
}
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
r.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if diff := cmp.Diff(ff.KV(), KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(36),
}); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(ff.Tensors(), Tensors{
Offset: 336,
items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
},
}, cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
}
}

fs/gguf/gguf.go (new file, 347 lines)
View File

@@ -0,0 +1,347 @@
package gguf
import (
"bytes"
"cmp"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"slices"
"strings"
)
const (
typeUint8 uint32 = iota
typeInt8
typeUint16
typeInt16
typeUint32
typeInt32
typeFloat32
typeBool
typeString
typeArray
typeUint64
typeInt64
typeFloat64
)
var ErrUnsupported = errors.New("unsupported")
type File struct {
Magic [4]byte
Version uint32
keyValues *lazy[KeyValue]
tensors *lazy[TensorInfo]
offset int64
file *os.File
reader *bufferedReader
bts []byte
}
func Open(path string) (f *File, err error) {
f = &File{bts: make([]byte, 4096)}
f.file, err = os.Open(path)
if err != nil {
return nil, err
}
f.reader = newBufferedReader(f.file, 32<<10)
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
return nil, err
}
if !bytes.Equal(f.Magic[:], []byte("GGUF")) {
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
}
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
return nil, err
}
if f.Version != 3 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}
f.tensors, err = newLazy(f, f.readTensor)
if err != nil {
return nil, err
}
f.tensors.successFunc = func() error {
offset := f.reader.offset
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
f.offset = offset + (alignment-offset%alignment)%alignment
return nil
}
f.keyValues, err = newLazy(f, f.readKeyValue)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) readTensor() (TensorInfo, error) {
name, err := readString(f)
if err != nil {
return TensorInfo{}, err
}
dims, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
shape := make([]uint64, dims)
for i := range dims {
shape[i], err = read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
}
type_, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
offset, err := read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
return TensorInfo{
Name: name,
Offset: offset,
Shape: shape,
Type: TensorType(type_),
}, nil
}
func (f *File) readKeyValue() (KeyValue, error) {
key, err := readString(f)
if err != nil {
return KeyValue{}, err
}
t, err := read[uint32](f)
if err != nil {
return KeyValue{}, err
}
value, err := func() (any, error) {
switch t {
case typeUint8:
return read[uint8](f)
case typeInt8:
return read[int8](f)
case typeUint16:
return read[uint16](f)
case typeInt16:
return read[int16](f)
case typeUint32:
return read[uint32](f)
case typeInt32:
return read[int32](f)
case typeUint64:
return read[uint64](f)
case typeInt64:
return read[int64](f)
case typeFloat32:
return read[float32](f)
case typeFloat64:
return read[float64](f)
case typeBool:
return read[bool](f)
case typeString:
return readString(f)
case typeArray:
return readArray(f)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}()
if err != nil {
return KeyValue{}, err
}
return KeyValue{
Key: key,
Value: Value{value},
}, nil
}
func read[T any](f *File) (t T, err error) {
err = binary.Read(f.reader, binary.LittleEndian, &t)
return t, err
}
func readString(f *File) (string, error) {
n, err := read[uint64](f)
if err != nil {
return "", err
}
if int(n) > len(f.bts) {
f.bts = make([]byte, n)
}
bts := f.bts[:n]
if _, err := io.ReadFull(f.reader, bts); err != nil {
return "", err
}
defer clear(bts)
return string(bts), nil
}
func readArray(f *File) (any, error) {
t, err := read[uint32](f)
if err != nil {
return nil, err
}
n, err := read[uint64](f)
if err != nil {
return nil, err
}
switch t {
case typeUint8:
return readArrayData[uint8](f, n)
case typeInt8:
return readArrayData[int8](f, n)
case typeUint16:
return readArrayData[uint16](f, n)
case typeInt16:
return readArrayData[int16](f, n)
case typeUint32:
return readArrayData[uint32](f, n)
case typeInt32:
return readArrayData[int32](f, n)
case typeUint64:
return readArrayData[uint64](f, n)
case typeInt64:
return readArrayData[int64](f, n)
case typeFloat32:
return readArrayData[float32](f, n)
case typeFloat64:
return readArrayData[float64](f, n)
case typeBool:
return readArrayData[bool](f, n)
case typeString:
return readArrayString(f, n)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
s = make([]T, n)
for i := range n {
e, err := read[T](f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func readArrayString(f *File, n uint64) (s []string, err error) {
s = make([]string, n)
for i := range n {
e, err := readString(f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func (f *File) Close() error {
f.keyValues.stop()
f.tensors.stop()
return f.file.Close()
}
func (f *File) KeyValue(key string) KeyValue {
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
key = f.KeyValue("general.architecture").String() + "." + key
}
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
return kv.Key == key
}); index >= 0 {
return f.keyValues.values[index]
}
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
if keyValue.Key == key {
return keyValue
}
}
return KeyValue{}
}
func (f *File) NumKeyValues() int {
return int(f.keyValues.count)
}
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
return f.keyValues.All()
}
func (f *File) TensorInfo(name string) TensorInfo {
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
return t.Name == name
}); index >= 0 {
return f.tensors.values[index]
}
// fast-forward through key values if we haven't already
_ = f.keyValues.rest()
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
if tensor.Name == name {
return tensor
}
}
return TensorInfo{}
}
func (f *File) NumTensors() int {
return int(f.tensors.count)
}
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
// fast forward through key values if we haven't already
f.keyValues.rest()
return f.tensors.All()
}
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
t := f.TensorInfo(name)
if t.NumBytes() == 0 {
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
}
// fast forward through tensor info if we haven't already
_ = f.tensors.rest()
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}

fs/gguf/gguf_test.go (new file, 249 lines)
View File

@@ -0,0 +1,249 @@
package gguf_test
import (
"bytes"
"os"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs/gguf"
)
func createBinFile(tb testing.TB) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(8),
"llama.embedding_length": uint32(3),
"llama.attention.head_count": uint32(2),
"llama.attention.head_count_kv": uint32(2),
"llama.attention.key_length": uint32(3),
"llama.rope.dimension_count": uint32(4),
"llama.rope.freq_base": float32(10000.0),
"llama.rope.freq_scale": float32(1.0),
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
"tokenizer.ggml.eos_token_id": uint32(0),
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
"tokenizer.ggml.tokens": []string{"hello", "world"},
"tokenizer.ggml.scores": []float32{0, 1},
}
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{2, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
},
{
Name: "output.weight",
Kind: 0,
Shape: []uint64{3, 2},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
},
}
for i := range 8 {
tensors = append(tensors, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
})
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestRead(t *testing.T) {
f, err := gguf.Open(createBinFile(t))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if got := f.KeyValue("does.not.exist").Valid(); got {
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
}
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if got.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
}
if got := f.KeyValue("block_count").Uint(); got != 8 {
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
}
var kvs []string
for _, kv := range f.KeyValues() {
if !kv.Valid() {
t.Error("found invalid key-value pair:", kv)
}
kvs = append(kvs, kv.Key)
}
if len(kvs) != f.NumKeyValues() {
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
}
if diff := cmp.Diff(kvs, []string{
"general.architecture",
"llama.block_count",
"llama.embedding_length",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.key_length",
"llama.rope.dimension_count",
"llama.rope.freq_base",
"llama.rope.freq_scale",
"llama.attention.layer_norm_rms_epsilon",
"tokenizer.ggml.eos_token_id",
"tokenizer.ggml.eos_token_ids",
"tokenizer.ggml.tokens",
"tokenizer.ggml.scores",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
}
var tis []string
for _, ti := range f.TensorInfos() {
if !ti.Valid() {
t.Error("found invalid tensor info:", ti)
}
tis = append(tis, ti.Name)
}
if len(tis) != f.NumTensors() {
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
}
if diff := cmp.Diff(tis, []string{
"token_embd.weight",
"output.weight",
"blk.0.attn_q.weight",
"blk.0.attn_k.weight",
"blk.0.attn_v.weight",
"blk.0.attn_output.weight",
"blk.1.attn_q.weight",
"blk.1.attn_k.weight",
"blk.1.attn_v.weight",
"blk.1.attn_output.weight",
"blk.2.attn_q.weight",
"blk.2.attn_k.weight",
"blk.2.attn_v.weight",
"blk.2.attn_output.weight",
"blk.3.attn_q.weight",
"blk.3.attn_k.weight",
"blk.3.attn_v.weight",
"blk.3.attn_output.weight",
"blk.4.attn_q.weight",
"blk.4.attn_k.weight",
"blk.4.attn_v.weight",
"blk.4.attn_output.weight",
"blk.5.attn_q.weight",
"blk.5.attn_k.weight",
"blk.5.attn_v.weight",
"blk.5.attn_output.weight",
"blk.6.attn_q.weight",
"blk.6.attn_k.weight",
"blk.6.attn_v.weight",
"blk.6.attn_output.weight",
"blk.7.attn_q.weight",
"blk.7.attn_k.weight",
"blk.7.attn_v.weight",
"blk.7.attn_output.weight",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
}
ti, r, err := f.TensorReader("output.weight")
if err != nil {
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
}
if ti.Name != "output.weight" {
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if ti.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
}
var b bytes.Buffer
if _, err := b.ReadFrom(r); err != nil {
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
}
if b.Len() != int(ti.NumBytes()) {
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
}
}
func BenchmarkRead(b *testing.B) {
b.ReportAllocs()
p := createBinFile(b)
for b.Loop() {
f, err := gguf.Open(p)
if err != nil {
b.Fatal(err)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
b.Errorf("got = %q, want %q", got, "llama")
}
// Iterate through some tensors
for range f.TensorInfos() {
}
f.Close()
}
}

fs/gguf/keyvalue.go

@@ -0,0 +1,90 @@
package gguf
import (
"reflect"
"slices"
)
type KeyValue struct {
Key string
Value
}
func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil
}
type Value struct {
value any
}
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
vv := reflect.ValueOf(v.value)
if slices.Contains(kinds, vv.Kind()) {
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
}
return
}
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
switch vv := reflect.ValueOf(v.value); vv.Kind() {
case reflect.Slice:
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
ts = make([]T, vv.Len())
for i := range vv.Len() {
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
}
}
}
return
}
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Uint returns Value as an unsigned integer. If it is not an unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Uints returns Value as an unsigned integer slice. If it is not an unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
return value[float64](v, reflect.Float32, reflect.Float64)
}
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
return values[float64](v, reflect.Float32, reflect.Float64)
}
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
return value[bool](v, reflect.Bool)
}
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
return values[bool](v, reflect.Bool)
}
// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
return value[string](v, reflect.String)
}
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
return values[string](v, reflect.String)
}
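A hedged sketch of the accessor semantics, mirroring TestRead above: lookups that miss, or hit a value of a different kind, return the zero value instead of an error.

package main

import (
    "fmt"
    "log"

    "github.com/ollama/ollama/fs/gguf"
)

func main() {
    f, err := gguf.Open("model.gguf") // hypothetical path
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    fmt.Println(f.KeyValue("general.architecture").String()) // e.g. "llama"
    fmt.Println(f.KeyValue("does.not.exist").Uint())         // 0: key missing
    fmt.Println(f.KeyValue("general.architecture").Int())    // 0: a string is not a signed integer
}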

fs/gguf/keyvalue_test.go

@@ -0,0 +1,208 @@
package gguf
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
for key, value := range values {
if key == name {
matched = value
} else {
unmatched = append(unmatched, value...)
}
}
return
}
func TestValue(t *testing.T) {
values := map[string][]any{
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
"float64": {float32(42), float64(42)},
"string": {"42", "hello"},
"bool": {true, false},
}
t.Run("int64", func(t *testing.T) {
matched, unmatched := split("int64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 42 {
t.Errorf("expected 42, got %d", i64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 0 {
t.Errorf("expected 42, got %d", i64)
}
}
})
t.Run("uint64", func(t *testing.T) {
matched, unmatched := split("uint64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 42 {
t.Errorf("expected 42, got %d", u64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 0 {
t.Errorf("expected 42, got %d", u64)
}
}
})
t.Run("float64", func(t *testing.T) {
matched, unmatched := split("float64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 42 {
t.Errorf("expected 42, got %f", f64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 0 {
t.Errorf("expected 42, got %f", f64)
}
}
})
t.Run("string", func(t *testing.T) {
matched, unmatched := split("string", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != v {
t.Errorf("expected 42, got %s", s)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != "" {
t.Errorf("expected 42, got %s", s)
}
}
})
t.Run("bool", func(t *testing.T) {
matched, unmatched := split("bool", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != v {
t.Errorf("expected true, got %v", b)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != false {
t.Errorf("expected false, got %v", b)
}
}
})
}
func TestValues(t *testing.T) {
values := map[string][]any{
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
"float64s": {[]float32{42}, []float64{42}},
"strings": {[]string{"42"}, []string{"hello"}},
"bools": {[]bool{true}, []bool{false}},
}
t.Run("int64s", func(t *testing.T) {
matched, unmatched := split("int64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64s := kv.Ints(); i64s != nil {
t.Errorf("expected nil, got %v", i64s)
}
}
})
t.Run("uint64s", func(t *testing.T) {
matched, unmatched := split("uint64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64s := kv.Uints(); u64s != nil {
t.Errorf("expected nil, got %v", u64s)
}
}
})
t.Run("float64s", func(t *testing.T) {
matched, unmatched := split("float64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64s := kv.Floats(); f64s != nil {
t.Errorf("expected nil, got %v", f64s)
}
}
})
t.Run("strings", func(t *testing.T) {
matched, unmatched := split("strings", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.Strings(); s != nil {
t.Errorf("expected nil, got %v", s)
}
}
})
t.Run("bools", func(t *testing.T) {
matched, unmatched := split("bools", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bools(); b != nil {
t.Errorf("expected nil, got %v", b)
}
}
})
}

fs/gguf/lazy.go

@@ -0,0 +1,89 @@
package gguf
import (
"encoding/binary"
"iter"
"log/slog"
)
type lazy[T any] struct {
count uint64
next func() (T, bool)
stop func()
values []T
// successFunc is called when all values have been successfully read.
successFunc func() error
}
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
it := lazy[T]{}
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
return nil, err
}
it.values = make([]T, 0)
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
for i := range it.count {
t, err := fn()
if err != nil {
slog.Error("error reading tensor", "index", i, "error", err)
return
}
it.values = append(it.values, t)
if !yield(t) {
break
}
}
if it.successFunc != nil {
if err := it.successFunc(); err != nil {
slog.Error("error completing read", "error", err)
}
}
})
return &it, nil
}
func (g *lazy[T]) Values() iter.Seq[T] {
return func(yield func(T) bool) {
for _, v := range g.All() {
if !yield(v) {
break
}
}
}
}
func (g *lazy[T]) All() iter.Seq2[int, T] {
return func(yield func(int, T) bool) {
for i := range int(g.count) {
if i < len(g.values) {
if !yield(i, g.values[i]) {
break
}
} else {
t, ok := g.next()
if !ok {
break
}
if !yield(i, t) {
break
}
}
}
}
}
func (g *lazy[T]) rest() (collected bool) {
for {
_, ok := g.next()
collected = collected || ok
if !ok {
break
}
}
return collected
}
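Conceptually, each next() pull decodes one element and memoizes it in values, so All() can be ranged more than once: replays come from memory and only the unread tail touches the file. A sketch inside package gguf (readOne is hypothetical and stands in for the real per-element decode closure, which isn't shown here):

// readOne is a hypothetical per-element decode function; the parser
// passes something equivalent when constructing f.tensors.
func exampleLazy(f *File, readOne func() (TensorInfo, error)) {
    it, err := newLazy(f, readOne)
    if err != nil {
        return
    }
    for _, ti := range it.All() {
        _ = ti // first pass: decoded on demand and memoized
    }
    for range it.Values() {
        // second pass: served from it.values; the file is not re-read
    }
}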

fs/gguf/reader.go

@@ -0,0 +1,23 @@
package gguf
import (
"bufio"
"io"
)
type bufferedReader struct {
offset int64
*bufio.Reader
}
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
return &bufferedReader{
Reader: bufio.NewReaderSize(rs, size),
}
}
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
n, err = rs.Reader.Read(p)
rs.offset += int64(n)
return n, err
}
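The offset field is the point of this wrapper: it counts bytes consumed by callers, independent of bufio's internal read-ahead, so the parser always knows its absolute position in the file. A small in-package sketch (assumes bytes and io are imported):

func exampleOffset(data []byte) {
    br := newBufferedReader(bytes.NewReader(data), 32<<10)
    var magic [4]byte
    if _, err := io.ReadFull(br, magic[:]); err == nil {
        _ = br.offset // 4: bytes handed to callers, even if bufio prefetched more
    }
}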

fs/gguf/tensor.go

@@ -0,0 +1,288 @@
package gguf
import (
"log/slog"
"strings"
)
type TensorInfo struct {
Name string
Offset uint64
Shape []uint64
Type TensorType
}
func (ti TensorInfo) Valid() bool {
return ti.Name != "" && ti.NumBytes() > 0
}
func (ti TensorInfo) NumValues() int64 {
var numItems int64 = 1
for _, dim := range ti.Shape {
numItems *= int64(dim)
}
return numItems
}
// NumBytes returns the number of bytes in the tensor.
func (ti TensorInfo) NumBytes() int64 {
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
}
func (ti TensorInfo) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", ti.Name),
slog.Int64("offset", int64(ti.Offset)),
slog.Any("shape", ti.Shape),
slog.Int64("num_values", ti.NumValues()),
slog.Int64("num_bytes", ti.NumBytes()),
slog.Any("type", ti.Type),
)
}
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
// unexported // unused in gguf
tensorTypeQ4_2
tensorTypeQ4_3
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
// unexported // unquantizable by ollama
tensorTypeIQ2_XXS
tensorTypeIQ2_XS
tensorTypeIQ3_XXS
tensorTypeIQ1_S
tensorTypeIQ4_NL
tensorTypeIQ3_S
tensorTypeIQ2_S
tensorTypeIQ4_XS
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
// unexported // unquantizable by ollama
tensorTypeIQ1_M
TensorTypeBF16
// unexported // unused in gguf
tensorTypeQ4_0_4_4
tensorTypeQ4_0_4_8
tensorTypeQ4_0_8_8
// unexported // unquantizable by ollama
tensorTypeTQ1_0
tensorTypeTQ2_0
// unexported // unused in gguf
tensorTypeIQ4_NL_4_4
tensorTypeIQ4_NL_4_8
tensorTypeIQ4_NL_8_8
)
func (tt TensorType) NumBytes() float64 {
return float64(tt.typeSize()) / float64(tt.blockSize())
}
func (tt TensorType) typeSize() int64 {
switch tt {
case TensorTypeF32:
return 4
case TensorTypeF16:
return 2
case TensorTypeQ4_0:
return 2 + tt.blockSize()/2
case TensorTypeQ4_1:
return 2 + 2 + tt.blockSize()/2
case TensorTypeQ5_0:
return 2 + 4 + tt.blockSize()/2
case TensorTypeQ5_1:
return 2 + 2 + 4 + tt.blockSize()/2
case TensorTypeQ8_0:
return 2 + tt.blockSize()
case TensorTypeQ8_1:
return 2 + 2 + tt.blockSize()
case TensorTypeQ2_K:
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
case TensorTypeQ3_K:
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
case TensorTypeQ4_K:
return 2 + 2 + 12 + tt.blockSize()/2
case TensorTypeQ5_K:
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
case TensorTypeQ6_K:
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
case TensorTypeQ8_K:
return 4 + tt.blockSize() + 2*tt.blockSize()/16
case tensorTypeIQ2_XXS:
return 2 + 2*tt.blockSize()/8
case tensorTypeIQ2_XS:
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
case tensorTypeIQ3_XXS:
return 2 + tt.blockSize()/4 + tt.blockSize()/8
case tensorTypeIQ1_S:
return 2 + tt.blockSize()/8 + tt.blockSize()/16
case tensorTypeIQ4_NL:
return 2 + tt.blockSize()/2
case tensorTypeIQ3_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
case tensorTypeIQ2_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/16
case tensorTypeIQ4_XS:
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
case TensorTypeI8:
return 1
case TensorTypeI16:
return 2
case TensorTypeI32:
return 4
case TensorTypeI64:
return 8
case TensorTypeF64:
return 8
case tensorTypeIQ1_M:
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (tt TensorType) blockSize() int64 {
switch tt {
case TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL:
return 32
default:
return 256
}
}
func (tt TensorType) String() string {
switch tt {
case TensorTypeF32:
return "f32"
case TensorTypeF16:
return "f16"
case TensorTypeQ4_0:
return "q4_0"
case TensorTypeQ4_1:
return "q4_1"
case tensorTypeQ4_2:
return "q4_2"
case tensorTypeQ4_3:
return "q4_3"
case TensorTypeQ5_0:
return "q5_0"
case TensorTypeQ5_1:
return "q5_1"
case TensorTypeQ8_0:
return "q8_0"
case TensorTypeQ8_1:
return "q8_1"
case TensorTypeQ2_K:
return "q2_k"
case TensorTypeQ3_K:
return "q3_k"
case TensorTypeQ4_K:
return "q4_k"
case TensorTypeQ5_K:
return "q5_k"
case TensorTypeQ6_K:
return "q6_k"
case TensorTypeQ8_K:
return "q8_k"
case tensorTypeIQ2_XXS:
return "iq2_xxs"
case tensorTypeIQ2_XS:
return "iq2_xs"
case tensorTypeIQ3_XXS:
return "iq3_xxs"
case tensorTypeIQ1_S:
return "iq1_s"
case tensorTypeIQ4_NL:
return "iq4_nl"
case tensorTypeIQ3_S:
return "iq3_s"
case tensorTypeIQ2_S:
return "iq2_s"
case tensorTypeIQ4_XS:
return "iq4_xs"
case TensorTypeI8:
return "i8"
case TensorTypeI16:
return "i16"
case TensorTypeI32:
return "i32"
case TensorTypeI64:
return "i64"
case TensorTypeF64:
return "f64"
case tensorTypeIQ1_M:
return "iq1_m"
case TensorTypeBF16:
return "bf16"
case tensorTypeQ4_0_4_4:
return "q4_0_4_4"
case tensorTypeQ4_0_4_8:
return "q4_0_4_8"
case tensorTypeQ4_0_8_8:
return "q4_0_8_8"
case tensorTypeTQ1_0:
return "tq1_0"
case tensorTypeTQ2_0:
return "tq2_0"
case tensorTypeIQ4_NL_4_4:
return "iq4_nl_4_4"
case tensorTypeIQ4_NL_4_8:
return "iq4_nl_4_8"
case tensorTypeIQ4_NL_8_8:
return "iq4_nl_8_8"
default:
return "unknown"
}
}
func (tt TensorType) LogValue() slog.Value {
return slog.GroupValue(
slog.Uint64("value", uint64(tt)),
slog.String("name", strings.ToUpper(tt.String())),
slog.Int64("size", tt.typeSize()),
slog.Int64("block_size", tt.blockSize()),
slog.Float64("num_bytes", tt.NumBytes()),
)
}
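A worked example of the size arithmetic: Q4_K packs 256 values into a block of typeSize 2+2+12+128 = 144 bytes, i.e. 0.5625 bytes per value, so a 4096x4096 tensor occupies 16,777,216 * 0.5625 = 9,437,184 bytes. As a sketch inside package gguf:

func exampleSize() {
    ti := TensorInfo{
        Name:  "example.weight", // hypothetical name
        Shape: []uint64{4096, 4096},
        Type:  TensorTypeQ4_K,
    }
    // typeSize(Q4_K) = 2 + 2 + 12 + 256/2 = 144 bytes per 256-value block
    fmt.Println(ti.NumBytes()) // 9437184
}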

go.mod

@@ -19,7 +19,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.6.0
github.com/google/go-cmp v0.7.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

go.sum

@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=

File diff suppressed because one or more lines are too long

@@ -30,6 +30,11 @@ type Causal struct {
// ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put
curLayer int
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !reserve {
if !c.curReserve {
c.updateSlidingWindow()
var err error
@@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {


@@ -0,0 +1,102 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Thu, 24 Apr 2025 14:48:51 -0700
Subject: [PATCH] ggml: Export GPU UUIDs
This enables matching up devices and information reported by the backend
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal.m | 1 +
3 files changed, 35 insertions(+)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 74e46716..a880df33 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -152,6 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
+ const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cb0d8528..4c829153 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
+ std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}
+static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->uuid.c_str();
+}
+
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
+ props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
+ #if !defined(GGML_USE_HIP)
+ char uuid[64];
+ snprintf(uuid, sizeof(uuid),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ (unsigned char)prop.uuid.bytes[0],
+ (unsigned char)prop.uuid.bytes[1],
+ (unsigned char)prop.uuid.bytes[2],
+ (unsigned char)prop.uuid.bytes[3],
+ (unsigned char)prop.uuid.bytes[4],
+ (unsigned char)prop.uuid.bytes[5],
+ (unsigned char)prop.uuid.bytes[6],
+ (unsigned char)prop.uuid.bytes[7],
+ (unsigned char)prop.uuid.bytes[8],
+ (unsigned char)prop.uuid.bytes[9],
+ (unsigned char)prop.uuid.bytes[10],
+ (unsigned char)prop.uuid.bytes[11],
+ (unsigned char)prop.uuid.bytes[12],
+ (unsigned char)prop.uuid.bytes[13],
+ (unsigned char)prop.uuid.bytes[14],
+ (unsigned char)prop.uuid.bytes[15]
+ );
+ dev_ctx->uuid = uuid;
+ #else
+ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1b56f858..ee4f2dcb 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
+ props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {


@@ -797,7 +797,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
slog.Error("post predict", "error", err)
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
}
defer res.Body.Close()


@@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"math"
"slices"
"strconv"
@@ -123,6 +124,10 @@ type DeviceMemory struct {
// may not be persistent across instances of the runner.
Name string
// UUID is a unique persistent identifier for the device for matching
// with system management libraries
UUID string
// Weights is the per-layer memory needed for the model weights.
Weights []Memory
@@ -133,6 +138,31 @@ type DeviceMemory struct {
Graph Memory
}
func memoryPresent(mem []Memory) bool {
return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph.Size != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.UUID != "" {
attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
@@ -150,6 +180,20 @@ type BackendMemory struct {
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights.Size != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {


@@ -136,6 +136,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
requiredMemory.CPU.UUID = C.GoString(props.uuid)
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
@@ -150,6 +153,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
})
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(d, &props)
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
}


@@ -152,6 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;


@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}
static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->uuid.c_str();
}
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
#if !defined(GGML_USE_HIP)
char uuid[64];
snprintf(uuid, sizeof(uuid),
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
(unsigned char)prop.uuid.bytes[0],
(unsigned char)prop.uuid.bytes[1],
(unsigned char)prop.uuid.bytes[2],
(unsigned char)prop.uuid.bytes[3],
(unsigned char)prop.uuid.bytes[4],
(unsigned char)prop.uuid.bytes[5],
(unsigned char)prop.uuid.bytes[6],
(unsigned char)prop.uuid.bytes[7],
(unsigned char)prop.uuid.bytes[8],
(unsigned char)prop.uuid.bytes[9],
(unsigned char)prop.uuid.bytes[10],
(unsigned char)prop.uuid.bytes[11],
(unsigned char)prop.uuid.bytes[12],
(unsigned char)prop.uuid.bytes[13],
(unsigned char)prop.uuid.bytes[14],
(unsigned char)prop.uuid.bytes[15]
);
dev_ctx->uuid = uuid;
#else
dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
#endif
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,


@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {


@@ -3,6 +3,7 @@ package model
import (
"cmp"
"context"
"fmt"
"iter"
"log/slog"
"strings"
@@ -210,6 +211,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
type lazyIdsString struct {
ids []int32
}
func (l lazyIdsString) LogValue() slog.Value {
return slog.AnyValue(fmt.Sprint(l.ids))
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
@@ -234,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}
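The point of lazyIdsString is that slog only calls LogValue when a record is actually emitted, so the fmt.Sprint over the ids is skipped entirely unless trace logging is enabled. A standalone illustration of the same slog.LogValuer pattern:

package main

import (
    "fmt"
    "log/slog"
)

type lazyInts struct{ ids []int32 }

// LogValue defers formatting until slog decides to emit the record.
func (l lazyInts) LogValue() slog.Value {
    return slog.AnyValue(fmt.Sprint(l.ids))
}

func main() {
    ids := []int32{15339, 1917}
    // At the default Info level this Debug record is dropped before
    // LogValue is called, so the slice is never formatted.
    slog.Debug("decoded", "ids", lazyInts{ids: ids})
    slog.Info("decoded", "ids", lazyInts{ids: ids})
}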


@@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {


@@ -66,9 +66,9 @@ type MLP interface {
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))


@@ -292,13 +292,18 @@ func filesForModel(path string) ([]string, error) {
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
// only include tokenizer.model if tokenizer.json is not present
if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be an unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
}
return files, nil


@@ -61,6 +61,8 @@ const (
ColorGrey = Esc + "[38;5;245m"
ColorDefault = Esc + "[0m"
ColorBold = Esc + "[1m"
StartBracketedPaste = Esc + "[?2004h"
EndBracketedPaste = Esc + "[?2004l"
)
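These are the standard xterm bracketed-paste toggles: once StartBracketedPaste is written, the terminal wraps pasted input in ESC[200~ ... ESC[201~ so it can be told apart from keystrokes. A hedged sketch of how they might be used (the surrounding wiring isn't part of this diff, and the readline package name is assumed):

fmt.Print(readline.StartBracketedPaste) // terminal now brackets pastes
// ... read input; a paste arrives as ESC[200~ <pasted bytes> ESC[201~ ...
fmt.Print(readline.EndBracketedPaste) // restore normal input on exit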

server/cache/capabilities.go

@@ -0,0 +1,115 @@
package cache
import (
"fmt"
"log/slog"
"os"
"slices"
"sync"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
)
// cacheEntry stores capabilities and the modification time of the model file
type cacheEntry struct {
capabilities []model.Capability
modTime time.Time
}
// ggufCapabilities is a cache for gguf model capabilities
var ggufCapabilities = &sync.Map{}
// ModelInfo contains the minimal information needed to determine capabilities
type ModelInfo struct {
ModelPath string
ProjectorPaths []string
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func Capabilities(info ModelInfo) []model.Capability {
capabilities, err := ggufCapabilties(info.ModelPath)
if err != nil {
slog.Error("could not determine gguf capabilities", "error", err)
}
if info.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(info.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(info.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(info.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(info.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
// Get file info to check modification time
fileInfo, err := os.Stat(modelPath)
if err != nil {
return nil, err
}
currentModTime := fileInfo.ModTime()
// Check if we have a cached entry
if cached, ok := ggufCapabilities.Load(modelPath); ok {
entry := cached.(cacheEntry)
// If the file hasn't been modified since we cached it, return the cached capabilities
if entry.modTime.Equal(currentModTime) {
return entry.capabilities, nil
}
}
// If not cached or file was modified, read the model file to determine capabilities
capabilities := []model.Capability{}
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
f, err := ggml.Decode(r, 1024)
if err != nil {
return nil, err
}
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
// Cache the capabilities with the modification time
ggufCapabilities.Store(modelPath, cacheEntry{
capabilities: capabilities,
modTime: currentModTime,
})
return capabilities, nil
}
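A usage sketch: the first call for a path decodes the GGUF and caches the result keyed by the file's modification time; later calls are served from the sync.Map until the mtime changes. The model path below is hypothetical; template.Parse is used the same way in the tests that follow.

tmpl, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
    log.Fatal(err)
}
caps := cache.Capabilities(cache.ModelInfo{
    ModelPath: "/models/example.gguf", // hypothetical path
    Template:  tmpl,
})
// caps includes CapabilityTools from the template plus whatever the
// gguf metadata implies (completion or embedding, and possibly vision).
_ = caps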

server/cache/capabilities_test.go

@@ -0,0 +1,211 @@
package cache
import (
"bytes"
"maps"
"os"
"slices"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
func testGGUF(tb testing.TB, customKV ggml.KV) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{}
maps.Copy(kv, customKV)
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{1, 1},
WriterTo: bytes.NewBuffer(make([]byte, 4)),
},
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestCapabilities(t *testing.T) {
ggufCapabilities.Range(func(key, value any) bool {
ggufCapabilities.Delete(key)
return true
})
// Create test model paths
completionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
})
visionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
})
embeddingModelPath := testGGUF(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
})
// Create templates
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testCases := []struct {
name string
model ModelInfo
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability from gguf",
model: ModelInfo{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision capability from projector",
model: ModelInfo{
ModelPath: completionModelPath,
ProjectorPaths: []string{"/path/to/projector"},
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: ModelInfo{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: ModelInfo{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// First call - should read from file
caps := Capabilities(tc.model)
slices.Sort(caps)
slices.Sort(tc.expectedCaps)
if !slices.Equal(caps, tc.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
}
// Verify caching for models that read from GGUF
if tc.model.ModelPath != "" {
// Check that entry is cached
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
if !ok {
t.Error("Expected capabilities to be cached")
}
// Second call - should use cache
caps2 := Capabilities(tc.model)
slices.Sort(caps2)
if !slices.Equal(caps, caps2) {
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
}
}
})
}
// Test cache invalidation on file modification
t.Run("cache invalidation", func(t *testing.T) {
// Use completion model for this test
info := ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
}
// Get initial cached entry
cached, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Fatal("Expected model to be cached from previous tests")
}
entry := cached.(cacheEntry)
// Modify the file's timestamp to the future
future := time.Now().Add(time.Hour)
err := os.Chtimes(completionModelPath, future, future)
if err != nil {
t.Fatalf("Failed to update file timestamp: %v", err)
}
// Call should re-read from file due to changed modtime
caps := Capabilities(info)
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
}
// Check that cache was updated with new modtime
cached2, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Error("Expected capabilities to be cached after re-read")
}
entry2 := cached2.(cacheEntry)
if entry2.modTime.Equal(entry.modTime) {
t.Error("Expected cache entry to have updated modTime")
}
})
}


@@ -464,6 +464,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf("%s: digest is empty", opts.mp.GetNamespaceRepository())
}
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return false, err


@@ -23,8 +23,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -37,6 +37,7 @@ var (
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errInsecureProtocol = errors.New("insecure protocol http")
)
@@ -66,58 +67,14 @@ type Model struct {
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
r, err := os.Open(m.ModelPath)
if err == nil {
defer r.Close()
f, err := ggml.Decode(r, 1024)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
} else {
slog.Error("couldn't decode ggml", "error", err)
}
} else {
slog.Error("couldn't open model file", "error", err)
}
if m.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(m.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
return capabilities
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(want ...model.Capability) error {
available := m.Capabilities()
available := cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
ProjectorPaths: m.ProjectorPaths,
Template: m.Template,
})
var errs []error
// Map capabilities to their corresponding error
@@ -127,6 +84,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
}
for _, cap := range want {
@@ -141,11 +99,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
}
}
var err error
if len(errs) > 0 {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
return nil
if slices.Contains(errs, errCapabilityThinking) {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
// append a message to the existing error
return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
}
}
return err
}
func (m *Model) String() string {


@@ -1,252 +1,42 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// Constants for GGUF magic bytes and version
var (
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
ggufVer = uint32(3) // Version 3
)
// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
var buf bytes.Buffer
// Write GGUF header
buf.Write(ggufMagic)
binary.Write(&buf, binary.LittleEndian, ggufVer)
// Write tensor count (0 for our test)
var numTensors uint64 = 0
binary.Write(&buf, binary.LittleEndian, numTensors)
// Calculate number of metadata entries
numMetaEntries := uint64(1) // architecture entry
if vision {
numMetaEntries++
}
// Add embedding entry if architecture is "bert"
if architecture == "bert" {
numMetaEntries++
}
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
// Write architecture metadata
archKey := "general.architecture"
keyLen := uint64(len(archKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(archKey)
// String type (8)
var strType uint32 = 8
binary.Write(&buf, binary.LittleEndian, strType)
// String length
strLen := uint64(len(architecture))
binary.Write(&buf, binary.LittleEndian, strLen)
buf.WriteString(architecture)
if vision {
visionKey := architecture + ".vision.block_count"
keyLen = uint64(len(visionKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(visionKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var countVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, countVal)
}
// Write embedding metadata if architecture is "bert"
if architecture == "bert" {
poolKey := architecture + ".pooling_type"
keyLen = uint64(len(poolKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(poolKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var poolingVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, poolingVal)
}
return buf.Bytes()
}
func TestModelCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
// Create different types of mock model files
completionModelPath := filepath.Join(tempDir, "model.bin")
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
// Create a simple model file for tests that don't depend on GGUF content
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
if err := errors.Join(
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools and insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
func TestModelCheckCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
// Create simple model file for tests that don't depend on GGUF content
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
simpleModelPath := filepath.Join(tempDir, "model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
if err := errors.Join(
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
@@ -261,7 +51,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "completion model without tools capability",
model: Model{
ModelPath: simpleModelPath,
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools},
@@ -270,7 +60,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model with all needed capabilities",
model: Model{
ModelPath: simpleModelPath,
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
@@ -278,7 +68,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model missing insert capability",
model: Model{
ModelPath: simpleModelPath,
ModelPath: completionModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityInsert},
@@ -287,7 +77,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model missing vision capability",
model: Model{
ModelPath: simpleModelPath,
ModelPath: completionModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
@@ -312,7 +102,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "unknown capability",
model: Model{
ModelPath: simpleModelPath,
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{"unknown"},


@@ -10,9 +10,6 @@ import (
"log/slog"
"net/http"
"os"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
func parseObjects(s string) []map[string]any {
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil
} else {
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
}
return objs
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}
templateObjects := parseObjects(b.String())
if len(templateObjects) == 0 {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
responseObjects := parseObjects(s)
if len(responseObjects) == 0 {
return nil, false
}
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
return toolCalls, len(toolCalls) > 0
}
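A minimal usage sketch of the removed parseToolCalls, mirroring the deleted test below (the template and output string are illustrative stand-ins, not values from this diff):
m := &Model{Template: tmpl} // tmpl ranges over .ToolCalls and renders them as JSON
output := `[{"name": "get_current_weather", "arguments": {"location": "Paris, France"}}]`
if calls, ok := m.parseToolCalls(output); ok {
	for _, call := range calls {
		fmt.Println(call.Function.Name, call.Function.Arguments)
	}
}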

View File

@@ -1,179 +0,0 @@
package server
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}
func TestParseObjects(t *testing.T) {
tests := []struct {
input string
want []map[string]any
}{
{
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
},
},
{
input: `{"name": "get_current_weather", "arguments": `,
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := parseObjects(tc.input)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -116,7 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
@@ -139,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", err
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil

View File

@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
thinkVal := false
if think != nil {
thinkVal = *think
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}
@@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
thinkVal := false
if think != nil {
thinkVal = *think
}
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}

View File

@@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

View File

@@ -257,16 +257,8 @@ func TestQuantizeModel(t *testing.T) {
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
f, err := os.CreateTemp(t.TempDir(), tt.name)
if err != nil {
t.Fatal(err.Error())
}
defer f.Close()
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
if err != nil {
t.Fatalf("failed to create initial model: %s", err)
}
fp, err := os.Open(f.Name())
p, _ := createBinFile(t, tt.kv, tt.tensors)
fp, err := os.Open(p)
if err != nil {
t.Fatal(err.Error())
}

View File

@@ -17,7 +17,6 @@ import (
"net/netip"
"os"
"os/signal"
"regexp"
"slices"
"strings"
"syscall"
@@ -35,9 +34,12 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/tools"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -185,6 +187,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
@@ -253,6 +262,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
values.Think = req.Think != nil && *req.Think
values.IsThinkSet = req.Think != nil
var b bytes.Buffer
if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
@@ -272,6 +284,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = b.String()
}
var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
}
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
@@ -296,6 +317,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -323,11 +350,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
var sb strings.Builder
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.GenerateResponse:
sb.WriteString(t.Response)
sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t
case gin.H:
msg, ok := t["error"].(string)
@@ -343,7 +372,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
r.Response = sb.String()
r.Thinking = sbThinking.String()
r.Response = sbContent.String()
c.JSON(http.StatusOK, r)
return
}
@@ -789,13 +820,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
Template: m.Template,
ProjectorPaths: m.ProjectorPaths,
}),
ModifiedAt: manifest.fi.ModTime(),
}
var params []string
@@ -1435,6 +1470,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model)
if !name.IsValid() {
@@ -1475,18 +1513,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
}
var toolParser *tools.Parser
if len(req.Tools) > 0 {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1506,43 +1557,41 @@ func (s *Server) ChatHandler(c *gin.Context) {
},
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
// TODO: tool call checking and filtering should be moved outside of this
// callback once streaming is reworked; this was a simple change for now
// without reworking the streaming logic of this (and other) handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// there is thinking content to send; don't return early
} else {
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -1550,12 +1599,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var sb strings.Builder
var toolCalls []api.ToolCall
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
sbThinking.WriteString(t.Message.Thinking)
sbContent.WriteString(t.Message.Content)
resp = t
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -1570,13 +1625,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
resp.Message.Content = sb.String()
resp.Message.Content = sbContent.String()
resp.Message.Thinking = sbThinking.String()
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
}
c.JSON(http.StatusOK, resp)
@@ -1601,8 +1654,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
@@ -1614,7 +1665,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
// TODO(drifkin): this is from before we added proper thinking support.
// However, even if thinking is not enabled (and therefore we shouldn't
// change the user output), we should probably perform this filtering
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
// to save tokens and improve quality.
thinkingState := &thinking.Parser{
OpeningTag: "<think>",
ClosingTag: "</think>",
}
_, content := thinkingState.AddContent(msg.Content)
msgs[i].Content = content
}
}
}
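A rough sketch of the incremental tool-call parsing that replaces parseToolCalls in the ChatHandler above, assuming only the tools.Parser methods shown in this hunk (NewParser, Add, Content); the chunks are hypothetical model output:
parser := tools.NewParser(m.Template.Template, req.Tools)
chunks := []string{`<toolcall>{"name": "get_current_weather", `, `"arguments": {"location": "Paris"}}</toolcall>`}
for _, chunk := range chunks {
	calls, content := parser.Add(chunk)
	if content != "" {
		// plain text that is safe to stream to the client immediately
	}
	for _, call := range calls {
		_ = call // a complete tool call, assembled across chunk boundaries
	}
}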

View File

@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("missing thinking capability", func(t *testing.T) {
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Think: &think,
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {

View File

@@ -112,11 +112,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err)
defer f.Close()
require.NoError(t, ggml.WriteGGUF(f, ggml.KV{
p, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
@@ -129,14 +125,14 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}))
require.NoError(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
b.f, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
})
model := &Model{Name: modelName, ModelPath: p}
f, err := llm.LoadModel(model.ModelPath, 0)
if err != nil {
t.Fatal(err)
}
b.f = f
if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond}
}

View File

@@ -1,67 +0,0 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,39 +0,0 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgeable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the user's location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,31 +0,0 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can choose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,17 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgeable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can choose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,43 +0,0 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{ .System }}
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools>
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
{{- range .Messages }}
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ if eq .Role "user" }}{{ .Content }}
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}
{{- else if eq .Role "tool" }}<tool_response>
{{ .Content }}
</tool_response>
{{- end }}<|eot_id|>
{{- end }}
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}
{{- if .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,24 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgeable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
<tool_response>
22
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,39 +0,0 @@
[
{
"role": "system",
"content": "You are a knowledgeable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

View File

@@ -1,15 +0,0 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

View File

@@ -1,3 +0,0 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

View File

@@ -1,33 +0,0 @@
{{- if (or .Tools .System) }}<extra_id_0>System
{{ if .System }}{{ .System }}
{{ end }}
{{- if .Tools }}
{{- range .Tools }}<tool> {{ . }} </tool>{{ end }}
{{ end }}
{{- end }}
{{- range $i, $m := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<extra_id_1>User
{{ .Content }}
{{- if $last }}
<extra_id_1>Assistant
{{- end }}
{{ else if eq .Role "tool" }}<extra_id_1>Tool
{{ .Content }}
{{- if $last }}
<extra_id_1>Assistant
{{- end }}
{{ else if eq .Role "assistant" }}<extra_id_1>Assistant
{{- if .ToolCalls }}
{{ range .ToolCalls }}<toolcall> {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} </toolcall> {{ end }}
{{ else }}
{{ .Content }}
{{- if not $last }}
{{ end }}
{{- end }}
{{- end }}
{{- end }}

View File

@@ -1,18 +0,0 @@
<extra_id_0>System
You are a knowledgeable assistant. You can answer questions and perform tasks.
<tool> {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} </tool>
<extra_id_1>User
What's the weather like today in Paris?
<extra_id_1>Assistant
<toolcall> {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} </toolcall>
<extra_id_1>Tool
22
<extra_id_1>Assistant
The current temperature in Paris, France is 22 degrees Celsius.
<extra_id_1>User
What's the weather like today in San Francisco and Toronto?
<extra_id_1>Assistant

View File

@@ -1,30 +0,0 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@@ -1,45 +0,0 @@
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:

View File

@@ -1,40 +0,0 @@
You are a knowledgeable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:

View File

@@ -167,6 +167,10 @@ type Values struct {
api.Tools
Prompt string
Suffix string
Think bool
// whether or not the user explicitly set the thinking flag (vs. it being
// implicitly false). Templates can't see whether `Think` is nil
IsThinkSet bool
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
}
@@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
for _, m := range messages {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}
@@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}
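With Think and IsThinkSet now threaded through every Execute path, a template can distinguish an explicit think=false from the flag being unset; a hypothetical snippet (not a template shipped in this change):
{{ if .IsThinkSet }}{{ if .Think }}/think{{ else }}/no_think{{ end }}{{ end }}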

171
thinking/parser.go Normal file
View File

@@ -0,0 +1,171 @@
package thinking
import (
"strings"
"unicode"
)
type thinkingState int
const (
// We're looking for the opening tag, but we haven't seen any non-whitespace
// characters yet
thinkingState_LookingForOpening thinkingState = iota
// We've seen the opening tag, but we haven't seen any non-whitespace
// characters yet (we want to eat any whitespace between the opening tag and
// the thinking content)
thinkingState_ThinkingStartedEatingWhitespace
// We've seen non-whitespace characters after the opening tag, but we haven't
// seen the closing tag yet
thinkingState_Thinking
// We've seen the closing tag, but we haven't seen any non-whitespace
// characters after the closing tag yet (we want to eat any whitespace between
// the closing tag and the content)
thinkingState_ThinkingDoneEatingWhitespace
// We've seen the closing tag and seen at least one non-whitespace character
// after it
thinkingState_ThinkingDone
)
func (s thinkingState) String() string {
switch s {
case thinkingState_LookingForOpening:
return "LookingForOpening"
case thinkingState_ThinkingStartedEatingWhitespace:
return "ThinkingStartedEatingWhitespace"
case thinkingState_Thinking:
return "Thinking"
case thinkingState_ThinkingDoneEatingWhitespace:
return "ThinkingDoneEatingWhitespace"
case thinkingState_ThinkingDone:
return "ThinkingDone"
default:
return "Unknown"
}
}
type Parser struct {
state thinkingState
OpeningTag string
ClosingTag string
acc strings.Builder
}
// AddContent returns the thinking content and the non-thinking content that
// should be immediately sent to the user. It will internally buffer if it needs
// to see more raw content to disambiguate
func (s *Parser) AddContent(content string) (string, string) {
s.acc.WriteString(content)
var thinkingSb, remainingSb strings.Builder
var thinking, remaining string
keepLooping := true
// we loop because we might pass through multiple parsing states in a single
// call to AddContent, and we want to make sure callers don't have to wait for
// data that's already unambiguous
for keepLooping {
thinking, remaining, keepLooping = eat(s)
thinkingSb.WriteString(thinking)
remainingSb.WriteString(remaining)
}
return thinkingSb.String(), remainingSb.String()
}
// the additional bool return is true iff we should continue eating
func eat(s *Parser) (string, string, bool) {
switch s.state {
case thinkingState_LookingForOpening:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, s.OpeningTag) {
after := strings.Join(strings.Split(trimmed, s.OpeningTag)[1:], s.OpeningTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
// after might contain more than just thinking tokens, so we continue
// parsing instead of returning it as thinking tokens here
s.acc.Reset()
s.acc.WriteString(after)
if after == "" {
s.state = thinkingState_ThinkingStartedEatingWhitespace
} else {
s.state = thinkingState_Thinking
}
return "", "", true
} else if strings.HasPrefix(s.OpeningTag, trimmed) {
// partial opening seen, so let's keep accumulating
return "", "", false
} else if trimmed == "" {
// saw whitespace only, so let's keep accumulating
return "", "", false
} else {
// didn't see an opening tag, but we have content, so thinking was skipped
s.state = thinkingState_ThinkingDone
// note that we use the original content, not the trimmed one because we
// don't want to eat any whitespace in the real content if there were no
// thinking tags
return "", s.acc.String(), false
}
case thinkingState_ThinkingStartedEatingWhitespace:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
s.acc.Reset()
if trimmed == "" {
return "", "", false
} else {
s.state = thinkingState_Thinking
s.acc.WriteString(trimmed)
return "", "", true
}
case thinkingState_Thinking:
acc := s.acc.String()
if strings.Contains(acc, s.ClosingTag) {
split := strings.Split(acc, s.ClosingTag)
thinking := split[0]
remaining := strings.Join(split[1:], s.ClosingTag)
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
s.acc.Reset()
if remaining == "" {
s.state = thinkingState_ThinkingDoneEatingWhitespace
} else {
s.state = thinkingState_ThinkingDone
}
return thinking, remaining, false
} else if overlapLen := overlap(acc, s.ClosingTag); overlapLen > 0 {
thinking := acc[:len(acc)-overlapLen]
remaining := acc[len(acc)-overlapLen:]
s.acc.Reset()
// keep track of the candidate closing tag. We have to buffer it until it
// becomes disambiguated
s.acc.WriteString(remaining)
return thinking, "", false
} else {
// purely just thinking tokens, so we can return them
s.acc.Reset()
return acc, "", false
}
case thinkingState_ThinkingDoneEatingWhitespace:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
s.acc.Reset()
// if we see non-whitespace, we're done eating the leading whitespace of the content
if trimmed != "" {
s.state = thinkingState_ThinkingDone
}
return "", trimmed, false
case thinkingState_ThinkingDone:
acc := s.acc.String()
s.acc.Reset()
return "", acc, false
default:
panic("unknown state")
}
}
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
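A brief sketch of the streaming behavior, mirroring the partial-closing-tag cases in the tests that follow:
p := &thinking.Parser{OpeningTag: "<think>", ClosingTag: "</think>"}
think, out := p.AddContent("<think>plan the answer</think")
// think == "plan the answer", out == "": the trailing "</think" is buffered
// because it may still turn out to be a partial closing tag
think, out = p.AddContent(">Hello!")
// think == "", out == "Hello!": the buffered bytes completed the closing tag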

278
thinking/parser_test.go Normal file
View File

@@ -0,0 +1,278 @@
package thinking
import (
"testing"
)
func TestExtractThinking(t *testing.T) {
tests := []struct {
in, wantContent, wantThink string
}{
{
in: "<think> internal </think> world",
wantThink: "internal ",
wantContent: "world",
},
{
in: "<think>a</think><think>b</think>c",
wantThink: "a",
wantContent: "<think>b</think>c",
},
{
in: "no think",
wantThink: "",
wantContent: "no think",
},
}
for i, tt := range tests {
parser := Parser{
OpeningTag: "<think>",
ClosingTag: "</think>",
}
gotThinking, gotContent := parser.AddContent(tt.in)
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
}
}
}
func TestThinkingStreaming(t *testing.T) {
type step struct {
input string
wantThinking string
wantContent string
wantStateAfter thinkingState
}
cases := []struct {
desc string
skip bool
steps []step
}{
{
desc: "content without a thinking tag",
steps: []step{
{
input: " abc",
wantThinking: "",
wantContent: " abc",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "content before a thinking tag nerfs the thinking tag",
steps: []step{
{
input: " abc <think>def</think> ghi",
wantThinking: "",
wantContent: " abc <think>def</think> ghi",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "building up a thinking tag partially",
steps: []step{
{
input: " <th",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_LookingForOpening,
},
{
input: "in",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_LookingForOpening,
},
{
input: "k>a",
wantThinking: "a",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
},
},
{
desc: "partial closing tag",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ink>def",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "partial closing tag fakeout",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ing>def",
wantThinking: "</thing>def",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "ghi</thi",
wantThinking: "ghi",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "nk>jkl",
wantThinking: "",
wantContent: "jkl",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag",
steps: []step{
{
input: " <think>abc</think>\n\ndef",
wantThinking: "abc",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag (incremental)",
steps: []step{
{
input: " <think>abc</think>",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "\n\ndef",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "whitespace after thinking tag with content and more whitespace",
steps: []step{
{
input: " <think>abc</think>\n\ndef ",
wantThinking: "abc",
wantContent: "def ",
wantStateAfter: thinkingState_ThinkingDone,
},
{
input: " ghi",
wantThinking: "",
wantContent: " ghi",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "token by token",
steps: []step{
{
input: "<think>",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: "\n",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: "</think>",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "\n\n",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: "Hi",
wantThinking: "",
wantContent: "Hi",
wantStateAfter: thinkingState_ThinkingDone,
},
{
input: " there",
wantThinking: "",
wantContent: " there",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{
desc: "leading thinking whitespace",
steps: []step{
{
input: " <think> \t ",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
},
{
input: " these are some ",
wantThinking: "these are some ",
wantContent: "",
wantStateAfter: thinkingState_Thinking,
},
{
input: "thoughts </think> ",
wantThinking: "thoughts ",
wantContent: "",
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
},
{
input: " more content",
wantThinking: "",
wantContent: "more content",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
}
for _, c := range cases {
parser := Parser{
OpeningTag: "<think>",
ClosingTag: "</think>",
}
if c.skip {
continue
}
for i, step := range c.steps {
thinking, content := parser.AddContent(step.input)
if content != step.wantContent || thinking != step.wantThinking {
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
}
if parser.state != step.wantStateAfter {
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state, step.wantStateAfter)
}
}
}
}

134
thinking/template.go Normal file
View File

@@ -0,0 +1,134 @@
package thinking
import (
"strings"
"text/template"
"text/template/parse"
)
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
if n == nil {
return
}
shouldContinue := enterFn(n)
if !shouldContinue {
return
}
switch x := n.(type) {
case *parse.ListNode:
for _, c := range x.Nodes {
templateVisit(c, enterFn, exitFn)
}
case *parse.BranchNode:
if x.Pipe != nil {
templateVisit(x.Pipe, enterFn, exitFn)
}
if x.List != nil {
templateVisit(x.List, enterFn, exitFn)
}
if x.ElseList != nil {
templateVisit(x.ElseList, enterFn, exitFn)
}
case *parse.ActionNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.WithNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.RangeNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.IfNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.TemplateNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.PipeNode:
for _, c := range x.Cmds {
templateVisit(c, enterFn, exitFn)
}
case *parse.CommandNode:
for _, a := range x.Args {
templateVisit(a, enterFn, exitFn)
}
// text, field, number, etc. are leaves; nothing to recurse into
}
if exitFn != nil {
exitFn(n)
}
}
// InferTags uses a heuristic to infer the tags that surround thinking traces:
// We look for a range node that iterates over "Messages" and then look for a
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
// ListNode and take the first and last TextNodes as the opening and closing
// tags.
func InferTags(t *template.Template) (string, string) {
ancestors := []parse.Node{}
openingTag := ""
closingTag := ""
enterFn := func(n parse.Node) bool {
ancestors = append(ancestors, n)
switch x := n.(type) {
case *parse.FieldNode:
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
var mostRecentRange *parse.RangeNode
for i := len(ancestors) - 1; i >= 0; i-- {
if r, ok := ancestors[i].(*parse.RangeNode); ok {
mostRecentRange = r
break
}
}
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
return true
}
// TODO(drifkin): to be more robust, check that it's in the action
// part, not the `if`'s pipeline part. We do match on the nearest list
// that starts and ends with text nodes, which makes this not strictly
// necessary for our heuristic
// go up to the nearest ancestor that is a *parse.ListNode
for i := len(ancestors) - 1; i >= 0; i-- {
if l, ok := ancestors[i].(*parse.ListNode); ok {
firstNode := l.Nodes[0]
if t, ok := firstNode.(*parse.TextNode); ok {
openingTag = strings.TrimSpace(t.String())
}
lastNode := l.Nodes[len(l.Nodes)-1]
if t, ok := lastNode.(*parse.TextNode); ok {
closingTag = strings.TrimSpace(t.String())
}
break
}
}
}
}
return true
}
exitFn := func(n parse.Node) {
ancestors = ancestors[:len(ancestors)-1]
}
templateVisit(t.Root, enterFn, exitFn)
return openingTag, closingTag
}
// rangeUsesField reports whether the given field name appears in the
// pipeline of the given range node
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
found := false
enterFn := func(n parse.Node) bool {
switch x := n.(type) {
case *parse.FieldNode:
if len(x.Ident) > 0 && x.Ident[0] == field {
found = true
}
}
return true
}
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
return found
}
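A minimal sketch of driving InferTags, assuming only the exported function above (qwen3Template stands in for any chat template whose {{ range .Messages }} body references {{ .Thinking }}, such as the qwen3 case in the tests below):

tmpl := template.Must(template.New("m").Parse(qwen3Template))
opening, closing := InferTags(tmpl)
if opening == "" || closing == "" {
// heuristic found nothing; a caller would fall back to model-specific defaults such as "<think>"/"</think>"
}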

130
thinking/template_test.go Normal file

@@ -0,0 +1,130 @@
package thinking
import (
"testing"
"text/template"
)
func TestInferThinkingTags(t *testing.T) {
cases := []struct {
desc string
tmplString string
wantOpeningTag string
wantClosingTag string
}{
{
desc: "basic",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
{
desc: "doubly nested range",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- range $j, $_ := .NotMessages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
{{ end }}
`,
wantOpeningTag: "",
wantClosingTag: "",
},
{
desc: "whitespace is trimmed",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
Some text before {{ .Thinking }} Some text after
{{ end }}
{{ end }}
`,
wantOpeningTag: "Some text before",
wantClosingTag: "Some text after",
},
{
desc: "qwen3",
tmplString: `
{{- if or .System .Tools .Thinking }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}
{{- if .Thinking }}
/think
{{- else }}
/no_think
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
}
for _, c := range cases {
tmpl := template.Must(template.New("test").Parse(c.tmplString))
openingTag, closingTag := InferTags(tmpl)
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
}
}
}

156
tools/template.go Normal file

@@ -0,0 +1,156 @@
package tools
import (
"bytes"
"log/slog"
"slices"
"strings"
"text/template"
"text/template/parse"
)
// parseTag finds the tool calling tag (often <tool_call>, [TOOL_CALL],
// or similar) in a Go template by locating the first text node after
// .ToolCalls and returning its content. If no tag is found, it returns
// "{" to indicate that JSON objects should be treated as potential
// tool calls.
func parseTag(tmpl *template.Template) string {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("template or tree is nil")
return "{"
}
tc := findToolCallNode(tmpl.Tree.Root.Nodes)
if tc == nil {
return "{"
}
tn := findTextNode(tc.List.Nodes)
if tn == nil {
return "{"
}
tag := string(tn.Text)
tag = strings.ReplaceAll(tag, "\r\n", "\n")
// cut the tag at the first '{' since anything from there on may be
// part of a tool call; if nothing precedes the '{', fall back to "{"
// below so that all JSON objects are treated as potential tool calls
tag, _, _ = strings.Cut(tag, "{")
tag = strings.TrimSpace(tag)
if tag == "" {
tag = "{"
}
return tag
}
// findToolCallNode returns the first IfNode whose pipeline references
// .ToolCalls, recursing into nested template constructs
func findToolCallNode(nodes []parse.Node) *parse.IfNode {
isToolCallsNode := func(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
for _, node := range nodes {
switch n := node.(type) {
case *parse.IfNode:
if isToolCallsNode(n) {
return n
}
// Recursively search in nested IfNodes
if result := findToolCallNode(n.List.Nodes); result != nil {
return result
}
if n.ElseList != nil {
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
return result
}
}
case *parse.ListNode:
if result := findToolCallNode(n.Nodes); result != nil {
return result
}
case *parse.RangeNode:
if result := findToolCallNode(n.List.Nodes); result != nil {
return result
}
if n.ElseList != nil {
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
return result
}
}
case *parse.WithNode:
if result := findToolCallNode(n.List.Nodes); result != nil {
return result
}
if n.ElseList != nil {
if result := findToolCallNode(n.ElseList.Nodes); result != nil {
return result
}
}
}
}
return nil
}
// findTextNode does a depth-first search for the first text content in nodes,
// stopping at template constructs to avoid parsing text after the tool calls
func findTextNode(nodes []parse.Node) *parse.TextNode {
for _, node := range nodes {
switch n := node.(type) {
case *parse.TextNode:
// skip whitespace-only text nodes
if len(bytes.TrimSpace(n.Text)) == 0 {
continue
}
return n
case *parse.IfNode:
if text := findTextNode(n.List.Nodes); text != nil {
return text
}
if n.ElseList != nil {
if text := findTextNode(n.ElseList.Nodes); text != nil {
return text
}
}
return nil
case *parse.ListNode:
if text := findTextNode(n.Nodes); text != nil {
return text
}
case *parse.RangeNode:
if text := findTextNode(n.List.Nodes); text != nil {
return text
}
if n.ElseList != nil {
if text := findTextNode(n.ElseList.Nodes); text != nil {
return text
}
}
return nil
case *parse.WithNode:
if text := findTextNode(n.List.Nodes); text != nil {
return text
}
if n.ElseList != nil {
if text := findTextNode(n.ElseList.Nodes); text != nil {
return text
}
}
return nil
case *parse.ActionNode:
return nil
}
}
return nil
}
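For intuition, a sketch of what parseTag yields on two of the template shapes exercised in the tests below (intra-package usage, since parseTag is unexported):

tagged := template.Must(template.New("t").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{{ . }}{{end}}</tool_call>{{end}}`))
tag := parseTag(tagged) // "<tool_call>"

bare := template.Must(template.New("b").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}`))
tag = parseTag(bare) // "{": no leading text node, so bare JSON objects are treated as potential tool calls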

139
tools/template_test.go Normal file

@@ -0,0 +1,139 @@
package tools
import (
"testing"
"text/template"
)
func TestParseTag(t *testing.T) {
cases := []struct {
name string
template string
want string
}{
{
name: "empty",
template: "",
want: "{",
},
{
name: "no tag",
template: "{{if .ToolCalls}}{{end}}",
want: "{",
},
{
name: "no tag with range",
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}",
want: "{",
},
{
name: "tool call with json format",
template: "{{if .ToolCalls}}```json\n{{end}}",
want: "```json",
},
{
name: "square brackets",
template: "{{if .ToolCalls}}[{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
want: "[",
},
{
name: "square brackets with whitespace",
template: "{{if .ToolCalls}}\n [ {{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
want: "[",
},
{
name: "trailing ]",
template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}]{{end}}",
want: "{",
},
{
name: "whitespace only",
template: "{{if .ToolCalls}} {{range .ToolCalls}}{{ . }}{{end}}{{end}}",
want: "{",
},
{
name: "whitespace only in range",
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{{ . }}\n{{end}}{{end}}",
want: "{",
},
{
name: "json objects",
template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`,
want: "{",
},
{
name: "json objects with whitespace",
template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
want: "{",
},
{
name: "json objects with CRLF",
template: "{{if .ToolCalls}}{{range .ToolCalls}}\r\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}",
want: "{",
},
{
name: "json objects with whitespace before and after range",
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\r\n{{end}}\r\n{{end}}",
want: "{",
},
{
name: "before and after range",
template: "{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>functionget_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>\n{{end}}<|tool▁calls▁end|>{{end}}",
want: "<|tool▁calls▁begin|>",
},
{
name: "after range",
template: "{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
want: "<tool_call>",
},
{
name: "after range with leading whitespace before range",
template: "{{if .ToolCalls}}\n{{range .ToolCalls}}<tool_call>{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}</tool_call>{{end}}{{end}}",
want: "<tool_call>",
},
{
name: "tool call in range with {",
template: `{{if .ToolCalls}}{{range .ToolCalls}}<tool_call>{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}<tool_call>{{end}}{{end}}`,
want: "<tool_call>",
},
{
name: "tool call with multiple text nodes",
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
want: "First text",
},
{
name: "action tag",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action: ```json",
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools[",
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL] [",
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL][",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got := parseTag(tmpl)
if got != tc.want {
t.Errorf("got text %q, want %q", got, tc.want)
}
})
}
}

287
tools/tools.go Normal file

@@ -0,0 +1,287 @@
package tools
import (
"bytes"
"encoding/json"
"strings"
"text/template"
"github.com/ollama/ollama/api"
)
type toolsState int
const (
toolsState_LookingForTag toolsState = iota
toolsState_ToolCalling
toolsState_Done
)
type Parser struct {
tag string
names []string
properties []string
state toolsState
buffer []byte
n int
}
// NewParser creates a new tool call parser from a model's chat
// template and a list of provided tools.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
return NewParserWithTag(tools, parseTag(tmpl))
}
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
var p Parser
for _, t := range tools {
p.names = append(p.names, t.Function.Name)
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
}
}
p.tag = tag
return &p
}
// Add feeds a chunk of model output to the parser, returning any
// complete tool calls parsed so far along with content that should be
// sent back to the user.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
if p.state == toolsState_Done {
return nil, s
}
p.buffer = append(p.buffer, s...)
if p.state == toolsState_LookingForTag {
i, found := p.findTag()
if i == -1 {
content = string(p.buffer)
p.buffer = []byte{}
} else {
content = string(p.buffer[:i])
p.buffer = p.buffer[i:]
}
// for models where { or [ are used as tool calling
// tags, we only support parsing tools if the first non-
// whitespace character is { or [
if p.tag == "{" || p.tag == "[" {
if strings.TrimSpace(content) != "" {
p.state = toolsState_Done
return nil, content + string(p.buffer)
}
}
if !found {
return nil, content
}
p.state = toolsState_ToolCalling
}
for {
call := p.parseToolCall()
if call == nil {
break
}
calls = append(calls, *call)
}
if p.done() {
p.state = toolsState_Done
content = string(p.buffer)
p.buffer = []byte{}
}
return calls, content
}
// findTag searches the buffer for the tool calling tag. It returns the
// index at which the tag (or a partial suffix of it) begins, or -1 if
// there is no overlap at all, along with whether the complete tag was
// found.
func (p *Parser) findTag() (int, bool) {
// First check for the complete tag anywhere in the buffer
if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
return i, true
}
// Then check for partial suffix overlap
max := min(len(p.buffer), len(p.tag))
for i := max; i > 0; i-- {
if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
return len(p.buffer) - i, false
}
}
return -1, false
}
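// Example of the partial-suffix case (a sketch mirroring the "partial
// tool call" test below): with tag "<tool_call>" and buffer
// "hello <tool_", findTag returns (6, false). The caller may emit the
// bytes before index 6 but must hold the rest, since "<tool_" may
// still grow into the full tag on the next chunk.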
// parseToolCall finds the next complete tool call in the buffer,
// incrementing n and advancing the buffer past it.
func (p *Parser) parseToolCall() *api.ToolCall {
var name string
var args map[string]any
end := len(p.buffer)
// find tool name
var i int
for _, n := range p.names {
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
name = n
end = i + len(n)
}
}
}
if name == "" {
return nil
}
if args, i = p.findArguments(); args == nil {
return nil
}
if i > end {
end = i
}
tc := &api.ToolCall{
Function: api.ToolCallFunction{
Name: name,
Arguments: args,
Index: p.n,
},
}
p.n++
p.buffer = p.buffer[end:]
return tc
}
// findArguments returns the first JSON object in the buffer that looks
// like tool arguments, along with the position where that object ends.
// It returns nil and 0 if the first balanced object is invalid JSON or
// contains no recognized argument properties.
func (p *Parser) findArguments() (map[string]any, int) {
if len(p.buffer) == 0 {
return nil, 0
}
var braces int
start := -1
var end int
var object []byte
// find any outer json object
for i, c := range p.buffer {
if c == '{' {
braces++
if start == -1 {
start = i
}
}
if c == '}' {
braces--
if braces == 0 && start != -1 {
end = i + 1
object = p.buffer[start:end]
break
}
}
}
if braces > 0 {
return nil, 0
}
var data map[string]any
// not valid json
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch v := obj.(type) {
case map[string]any:
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
}
}
for _, value := range v {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range v {
if result := find(item); result != nil {
return result
}
}
}
return nil
}
result := find(data)
if result != nil {
return result, end
}
return nil, 0
}
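// For example (mirroring the "valid arguments with special tokens"
// test below), given the buffer
// [tool]get_temperature[args]{"format": "fahrenheit", "location": "San Francisco, CA"}[end]
// the brace scan isolates the JSON object and find returns it, since
// "format" and "location" are registered tool properties.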
// done reports whether the parser has finished parsing by looking for
// a closing delimiter. Only "}" and "]" are supported as closers (for
// the "{" and "[" tags), since those pairs may not always represent
// tool calls and the buffered content must eventually be sent back to
// the user.
func (p *Parser) done() bool {
var opener, closer byte
switch p.tag {
case "{":
opener, closer = '{', '}'
case "[":
opener, closer = '[', ']'
default:
return false
}
var count int
for _, c := range p.buffer {
if c == opener {
count++
} else if c == closer {
count--
if count == 0 {
return true
}
}
}
return false
}
// Content returns any remaining content that should be sent to the
// user. This is the empty string unless the tag is { or [ and no tool
// call was found.
func (p *Parser) Content() string {
if p.n > 0 {
return ""
}
if p.tag == "{" || p.tag == "[" {
return string(p.buffer)
}
return ""
}
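A rough end-to-end sketch of the streaming flow (the chunk source and the dispatch/send helpers are assumptions; the Parser API is as defined above):

p := NewParser(tmpl, toolDefs) // tmpl: the model's chat template; toolDefs: []api.Tool
for _, chunk := range chunks { // chunks: streamed model output, an assumption
calls, content := p.Add(chunk)
for _, c := range calls {
dispatch(c) // hypothetical executor for an api.ToolCall
}
if content != "" {
send(content) // hypothetical user-facing stream
}
}
if tail := p.Content(); tail != "" {
send(tail) // flush buffered text that turned out not to be a tool call
}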

805
tools/tools_test.go Normal file

@@ -0,0 +1,805 @@
package tools
import (
"testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestParser(t *testing.T) {
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
deepseek, err := template.New("deepseek").Parse("{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>{{end}}<|tool▁calls▁end|><|end▁of▁sentence|>{{end}}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
json, err := template.New("json").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
mistral, err := template.New("mistral").Parse(`{{if .ToolCalls}}[TOOL_CALLS] [{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}][/TOOL_CALLS]{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
list, err := template.New("list").Parse(`{{if .ToolCalls}}[{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}]{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"city": {
Type: api.PropertyType{"string"},
Description: "The city to get the temperature for",
},
},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_conditions",
Description: "Retrieve the current weather conditions for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the weather conditions for",
},
},
},
},
},
}
tests := []struct {
name string
inputs []string
tmpl *template.Template
content string
calls []api.ToolCall
}{
{
name: "no tool calls - just text",
inputs: []string{"Hello, how can I help you today?"},
content: "Hello, how can I help you today?",
tmpl: qwen,
calls: nil,
},
{
name: "empty input",
inputs: []string{""},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "tool call",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"location": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco",
},
},
},
},
},
{
name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
content: "Let me check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "New York",
},
},
},
},
},
{
name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
content: "",
tmpl: mistral,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "two tool calls",
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
content: "Okay, let's call both tools! ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "deepseek",
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
content: "<think>Wait, I need to call a tool</think>",
tmpl: deepseek,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
},
},
},
},
{
name: "deepseek incremental",
inputs: []string{
"<think>Wait",
", I need",
" to call",
" a tool</think><|too",
"l▁calls▁begin",
"|>",
"<|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n",
"```json\n",
"{\"city\": \"Tokyo\"}\n",
"```",
"<|tool▁c", "all▁end|>",
"<|tool▁calls▁end|>",
"<|end▁of▁sentence|>",
},
content: "<think>Wait, I need to call a tool</think>",
tmpl: deepseek,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
},
},
},
},
{
name: "json",
inputs: []string{
"{",
"\"name\": \"get_temperature\",",
"\"arguments\": {",
"\"city\": \"Tokyo\"",
"}",
"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
},
},
},
},
{
name: "json maybe a tool call",
inputs: []string{
"{",
"\"name\": \"get_temperature\",",
"\"arguments\": {",
},
content: "",
tmpl: json,
calls: nil,
},
{
name: "json not a tool call",
inputs: []string{
"{",
"\"name\": \"search\", ",
"\"arguments\": {",
"\"query\": \"What is the capital of Canada?\"",
"}",
"}",
},
content: "{\"name\": \"search\", \"arguments\": {\"query\": \"What is the capital of Canada?\"}}",
tmpl: json,
calls: nil,
},
{
name: "json object followed by tool call",
inputs: []string{
"{\"name\": \"jeff\"}",
"{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
},
content: "{\"name\": \"jeff\"}{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
tmpl: json,
},
{
name: "json object followed by tool call split",
inputs: []string{
"{\"name\": \"jeff\"} {",
"\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
},
content: "{\"name\": \"jeff\"} {\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
tmpl: json,
},
{
name: "json code",
inputs: []string{
"for { fmt.Println(\"hello\") }",
},
content: "for { fmt.Println(\"hello\") }",
tmpl: json,
},
{
name: "list multiple",
inputs: []string{
"[",
"{",
"\"name\": \"get_temperature\", ",
"\"arguments\": {",
"\"city\": \"London\"",
"}",
"},",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list partial",
inputs: []string{
"[",
"{",
"\"name\": \"search\", ",
"\"arguments\": {",
"\"query\": \"What is the capital of Canada?\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: nil,
},
{
name: "list not a tool call",
inputs: []string{
"[special",
" del",
"ivery]",
},
content: "[special delivery]",
tmpl: list,
calls: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.tmpl, tools)
var calls []api.ToolCall
var content string
for _, input := range tt.inputs {
tcs, c := parser.Add(input)
calls = append(calls, tcs...)
content += c
}
if content != tt.content {
t.Errorf("Expected content %q, got %q", tt.content, content)
}
if len(calls) != len(tt.calls) {
t.Fatalf("Expected %d tool calls, got %d", len(tt.calls), len(calls))
}
for i, want := range tt.calls {
if diff := cmp.Diff(calls[i], want); diff != "" {
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
}
}
})
}
}
func TestDone(t *testing.T) {
tests := []struct {
name string
tag string
buffer []byte
want bool
}{
{
name: "empty",
tag: "<tool_call>",
buffer: []byte{},
want: false,
},
{
name: "empty",
tag: "<tool_call>",
buffer: []byte{},
want: false,
},
{
name: "json open",
tag: "{",
buffer: []byte("{\"name\": \"get_weather\""),
want: false,
},
{
name: "json closed",
tag: "{",
buffer: []byte("{\"name\": \"get_weather\"}"),
want: true,
},
{
name: "json empty",
tag: "{",
buffer: []byte("{}"),
want: true,
},
{
name: "list open",
tag: "[",
buffer: []byte("[{\"name\": \"get_weather\""),
want: false,
},
{
name: "list closed",
tag: "[",
buffer: []byte("[{\"name\": \"get_weather\"}]"),
want: true,
},
{
name: "list empty",
tag: "[",
buffer: []byte("[]"),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.buffer,
}
got := parser.done()
if got != tt.want {
t.Errorf("done() = %t, want %t", got, tt.want)
}
})
}
}
func TestContent(t *testing.T) {
tests := []struct {
name string
tag string
content []byte
want string
n int
}{
{
name: "empty",
content: []byte{},
tag: "{",
want: "",
n: 0,
},
{
name: "tag",
tag: "<tool_call>",
content: []byte("<tool_call>{\"name\": \"get_temperature\""),
want: "",
n: 0,
},
{
name: "json object",
tag: "{",
content: []byte("{\"name\": \"get_temperature\"}"),
want: "{\"name\": \"get_temperature\"}",
n: 0,
},
{
name: "json object after called",
tag: "{",
content: []byte("{\"hello\": \"world\"}"),
want: "{\"hello\": \"world\"}",
n: 0,
},
{
name: "json object after called",
tag: "{",
content: []byte("{\"hello\": \"world\"}"),
want: "",
n: 1,
},
{
name: "list",
tag: "[",
content: []byte("[{\"name\": \"get_temperature\"}]"),
want: "[{\"name\": \"get_temperature\"}]",
n: 0,
},
{
name: "code",
tag: "{",
content: []byte("{ fmt.Println(\"hello\")"),
want: "{ fmt.Println(\"hello\")",
n: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.content,
n: tt.n,
}
got := parser.Content()
if got != tt.want {
t.Errorf("Content() = %q, want %q", got, tt.want)
}
})
}
}
func TestFindTag(t *testing.T) {
cases := []struct {
name string
buffer []byte
tag string
i int
found bool
}{
{
name: "no overlap",
buffer: []byte("hello world"),
tag: "<tool_call>",
i: -1,
found: false,
},
{
name: "full overlap",
buffer: []byte("<tool_call>"),
tag: "<tool_call>",
i: 0,
found: true,
},
{
name: "whitespace",
buffer: []byte(" <tool_call>\n {\"name\": \"bob\"}"),
tag: "<tool_call>",
i: 4,
found: true,
},
{
name: "over",
buffer: []byte("<tool_call>{\"name\""),
tag: "<tool_call>",
i: 0,
found: true,
},
{
name: "partial overlap",
buffer: []byte("text <tool_call>"),
tag: "<tool_call>",
i: 5,
found: true,
},
{
name: "overlap with extra",
buffer: []byte("<tool_calls><tool_call>"),
tag: "<tool_calls>",
i: 0,
found: true,
},
{
name: "delimiter longer than string",
buffer: []byte("<tool>"),
tag: "<tool_call>",
i: -1,
found: false,
},
{
name: "empty string",
buffer: []byte{},
tag: "<tool_call>",
i: -1,
found: false,
},
{
name: "single char overlap",
buffer: []byte("test<"),
tag: "<tool_call>",
i: 4,
found: false,
},
{
name: "partial tool call",
buffer: []byte("hello <tool_"),
tag: "<tool_call>",
i: 6,
found: false,
},
{
name: "square bracket",
buffer: []byte("calling tools: ["),
tag: "[",
i: 15,
found: true,
},
{
name: "bracket",
buffer: []byte("{\"name\": \"bob\""),
tag: "{",
i: 0,
found: true,
},
{
name: "bracket with whitespace",
buffer: []byte("\n\n{\n\"name\": \"bob\""),
tag: "{",
i: 2,
found: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.buffer,
n: 0,
}
i, found := parser.findTag()
if i != tt.i {
t.Errorf("findTag(%q, %q) = %d; want %d", tt.buffer, tt.tag, i, tt.i)
}
if found != tt.found {
t.Errorf("findTag(%q, %q) = %t; want %t", tt.buffer, tt.tag, found, tt.found)
}
})
}
}
func TestFindArguments(t *testing.T) {
tests := []struct {
name string
buffer []byte
want map[string]any
}{
{
name: "empty string",
buffer: []byte{},
want: nil,
},
{
name: "whitespace only",
buffer: []byte(" \n\t "),
want: nil,
},
{
name: "unbalanced braces - missing closing",
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
want: nil,
},
{
name: "unbalanced braces - extra closing",
buffer: []byte(`{"format": "fahrenheit"}}`),
want: map[string]any{
"format": "fahrenheit",
},
},
{
name: "invalid JSON",
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
want: nil,
},
{
name: "valid json",
buffer: []byte(`{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
{
name: "valid arguments with special tokens",
buffer: []byte(`[tool]get_temperature[args]{"format": "fahrenheit", "location": "San Francisco, CA"}[end]`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
{
name: "valid arguments in array",
buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
{
name: "nested deep",
buffer: []byte(`{"function": {"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
{
name: "one arg",
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
want: map[string]any{
"location": "San Francisco, CA",
},
},
{
name: "two args",
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
want: map[string]any{
"location": "San Francisco, CA",
"format": "fahrenheit",
},
},
{
name: "deepseek",
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{
"location": "Tokyo",
},
},
}
for _, tt := range tests {
parser := &Parser{
buffer: tt.buffer,
properties: []string{"format", "location"},
}
t.Run(tt.name, func(t *testing.T) {
got, _ := parser.findArguments()
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
}
})
}
}


@@ -8,6 +8,7 @@ const (
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
)
func (c Capability) String() string {