Compare commits

...

18 Commits

Author SHA1 Message Date
Bruce MacDonald
365a3657ad fix test home on windows 2026-02-18 18:37:24 -08:00
Bruce MacDonald
71c1d8d0a9 cmd: ollama update
Add interactive update check to CLI TUI and `ollama update` command

On TUI launch, check for updates in the background and cache the result as a marker file (~/.ollama/update). On the next launch, if a cached update exists, print a one-line notice before the TUI starts. The check is skipped for dev builds (0.0.0), alternative installs (e.g. brew, choco), and remote Ollama hosts.

Add `ollama update` subcommand that downloads and runs the platform-appropriate install script (install.sh on Linux/macOS, install.ps1 on Windows). Refuses to run if the binary wasn't installed via official channels unless --force is passed.

Co-Authored-By: RajeshKumar11 <22585507+rajeshkumar11@users.noreply.github.com>
2026-02-18 18:21:17 -08:00
Parth Sareen
325b72bc31 cmd/tui: default to single-select for editor integrations (#14302) 2026-02-17 18:17:27 -08:00
Patrick Devine
f01a9a7859 chore: update mlx-c bindings to 0.5.0 (#14303) 2026-02-17 16:48:16 -08:00
Patrick Devine
9aefd2dfee model: add qwen3 support to mlxrunner (#14293) 2026-02-17 13:58:49 -08:00
Patrick Devine
d07e4a1dd3 bugfix: better mlx model scheduling (#14290)
This fixes a bug with current MLX based models which don't get loaded/unloaded correctly. The first model currently gets loaded and then subsequent model starts get shunted to the first runner which results in the wrong model being run.
2026-02-17 13:57:05 -08:00
Parth Sareen
8a257ec00a docs: make integrations more discoverable (#14301)
* docs: add Pi integration page

* docs: flatten integration sidebar with expanded subheadings

* docs: add OpenClaw and Claude Code to quickstart
2026-02-17 13:27:25 -08:00
Parth Sareen
2f4de1acf7 cmd: ollama launch always show model picker (#14299) 2026-02-17 12:02:14 -08:00
Parth Sareen
ec95c45f70 cmd/config: ollama launch cline CLI (#14294) 2026-02-17 11:37:53 -08:00
Patrick Devine
3a88f7eb20 bugfix: add missing linear layer factory (#14289) 2026-02-16 17:22:20 -08:00
Patrick Devine
0d5da826d4 bugfix: display the parameter count correctly in mlx for ollama show (#14285) 2026-02-16 13:03:34 -08:00
Patrick Devine
9b795698b8 model: add llama3 architecture to mlxrunner (#14277) 2026-02-15 23:06:28 -08:00
Patrick Devine
041fb77639 model: add gemma3 to the mlxrunner (#14276)
This change adds the gemma3 model to the mlxrunner and simplifies some of the quantization
code for loading weights.
2026-02-15 22:47:59 -08:00
Saumil Shah
8224cce583 readme: update download link for macOS (#1) (#14271) 2026-02-15 15:25:15 -08:00
Patrick Devine
d18dcd7775 mlxrunner fixes (#14247)
* load glm4_moe_lite from the mlxrunner

* fix loading diffusion models

* remove log lines

* fix --imagegen flag
2026-02-13 22:30:42 -08:00
Parth Sareen
5f5ef20131 anthropic: enable websearch (#14246) 2026-02-13 19:20:46 -08:00
Parth Sareen
f0a07a353b cmd/tui: fix powershell search (#14242) 2026-02-13 15:53:11 -08:00
Devon Rifkin
948de6bbd2 add ability to disable cloud (#14221)
* add ability to disable cloud

Users can now easily opt-out of cloud inference and web search by
setting

```
"disable_ollama_cloud": true
```

in their `~/.ollama/server.json` settings file. After a setting update,
the server must be restarted.

Alternatively, setting the environment variable `OLLAMA_NO_CLOUD=1` will
also disable cloud features. While users previously were able to avoid
cloud models by not pulling or `ollama run`ing them, this gives them an
easy way to enforce that decision. Any attempt to run a cloud model when
cloud is disabled will fail.

The app's old "airplane mode" setting, which did a similar thing for
hiding cloud models within the app is now unified with this new cloud
disabled mode. That setting has been replaced with a "Cloud" toggle,
which behind the scenes edits `server.json` and then restarts the
server.

* gate cloud models across TUI and launch flows when cloud is disabled

Block cloud models from being selected, launched, or written to
integration configs when cloud mode is turned off:

- TUI main menu: open model picker instead of launching with a
  disabled cloud model
- cmd.go: add IsCloudModelDisabled checks for all Selection* paths
- LaunchCmd: filter cloud models from saved Editor configs before
  launch, fall through to picker if none remain
- Editor Run() methods (droid, opencode, openclaw): filter cloud
  models before calling Edit() and persist the cleaned list
- Export SaveIntegration, remove SaveIntegrationModel wrapper that
  was accumulating models instead of replacing them

* rename saveIntegration to SaveIntegration in config.go and tests

* cmd/config: add --model guarding and empty model list fixes

* Update docs/faq.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Update internal/cloud/policy.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Update internal/cloud/policy.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Update server/routes.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Revert "Update internal/cloud/policy.go"

This reverts commit 8bff8615f9.

Since this error shows up in other integrations, we want it to be
prefixed with Ollama

* rename cloud status

* more status renaming

* fix tests that weren't updated after rename

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2026-02-12 15:47:00 -08:00
113 changed files with 11808 additions and 1505 deletions

View File

@@ -1 +1 @@
v0.4.1
v0.5.0

View File

@@ -16,7 +16,7 @@ Start building with open models.
curl -fsSL https://ollama.com/install.sh | sh
```
or [download manually](http://localhost:8080/download/Ollama.dmg)
or [download manually](https://ollama.com/download/Ollama.dmg)
### Windows

View File

@@ -1,17 +1,25 @@
package anthropic
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/logutil"
)
// Error types matching Anthropic API
@@ -82,22 +90,25 @@ type MessageParam struct {
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For text blocks with citations
Citations []Citation `json:"citations,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
// For tool_use and server_tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
// For tool_result and web_search_tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
@@ -105,6 +116,30 @@ type ContentBlock struct {
Signature string `json:"signature,omitempty"`
}
// Citation represents a citation in a text block
type Citation struct {
Type string `json:"type"` // "web_search_result_location"
URL string `json:"url"`
Title string `json:"title"`
EncryptedIndex string `json:"encrypted_index,omitempty"`
CitedText string `json:"cited_text,omitempty"`
}
// WebSearchResult represents a single web search result
type WebSearchResult struct {
Type string `json:"type"` // "web_search_result"
URL string `json:"url"`
Title string `json:"title"`
EncryptedContent string `json:"encrypted_content,omitempty"`
PageAge string `json:"page_age,omitempty"`
}
// WebSearchToolResultError represents an error from web search
type WebSearchToolResultError struct {
Type string `json:"type"` // "web_search_tool_result_error"
ErrorCode string `json:"error_code"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
@@ -115,10 +150,13 @@ type ImageSource struct {
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
// Web search specific fields
MaxUses int `json:"max_uses,omitempty"`
}
// ToolChoice controls how the model uses tools
@@ -233,6 +271,8 @@ type StreamErrorEvent struct {
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))
var messages []api.Message
if r.System != nil {
@@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
}
}
for _, msg := range r.Messages {
for i, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
return nil, err
}
messages = append(messages, converted...)
@@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
}
var tools api.Tools
hasBuiltinWebSearch := false
for _, t := range r.Tools {
tool, err := convertTool(t)
if strings.HasPrefix(t.Type, "web_search") {
hasBuiltinWebSearch = true
break
}
}
for _, t := range r.Tools {
// Anthropic built-in web_search maps to Ollama function name "web_search".
// If a user-defined tool also uses that name in the same request, drop the
// user-defined one to avoid ambiguous tool-call routing.
if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
continue
}
tool, _, err := convertTool(t)
if err != nil {
return nil, err
}
@@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
}
stream := r.Stream
return &api.ChatRequest{
convertedRequest := &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))
return convertedRequest, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
@@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid content block format", "role", role)
return nil, errors.New("invalid content block format")
}
@@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
switch blockType {
case "text":
textBlocks++
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
imageBlocks++
source, ok := blockMap["source"].(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid image source", "role", role)
return nil, errors.New("invalid image source")
}
@@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use":
toolUseBlocks++
id, ok := blockMap["id"].(string)
if !ok {
logutil.Trace("anthropic: tool_use block missing id", "role", role)
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
@@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
toolCalls = append(toolCalls, tc)
case "tool_result":
toolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
@@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
})
case "thinking":
thinkingBlocks++
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
case "server_tool_use":
serverToolUseBlocks++
id, _ := blockMap["id"].(string)
name, _ := blockMap["name"].(string)
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "web_search_tool_result":
webSearchToolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(blockMap["content"]),
ToolCallID: toolUseID,
})
default:
unknownBlocks++
}
}
@@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
// Add tool results as separate messages
messages = append(messages, toolResults...)
logutil.Trace("anthropic: converted block message",
"role", role,
"blocks", len(content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
@@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
func formatWebSearchToolResultContent(content any) string {
switch c := content.(type) {
case string:
return c
case []WebSearchResult:
var resultContent strings.Builder
for _, item := range c {
if item.Type != "web_search_result" {
continue
}
fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
}
return resultContent.String()
case []any:
var resultContent strings.Builder
for _, item := range c {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
switch itemMap["type"] {
case "web_search_result":
title, _ := itemMap["title"].(string)
url, _ := itemMap["url"].(string)
fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
case "web_search_tool_result_error":
errorCode, _ := itemMap["error_code"].(string)
if errorCode == "" {
return "web_search_tool_result_error"
}
return "web_search_tool_result_error: " + errorCode
}
}
return resultContent.String()
case map[string]any:
if c["type"] == "web_search_tool_result_error" {
errorCode, _ := c["error_code"].(string)
if errorCode == "" {
return "web_search_tool_result_error"
}
return "web_search_tool_result_error: " + errorCode
}
data, err := json.Marshal(c)
if err != nil {
return ""
}
return string(data)
case WebSearchToolResultError:
if c.ErrorCode == "" {
return "web_search_tool_result_error"
}
return "web_search_tool_result_error: " + c.ErrorCode
default:
data, err := json.Marshal(c)
if err != nil {
return ""
}
return string(data)
}
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
func convertTool(t Tool) (api.Tool, bool, error) {
if strings.HasPrefix(t.Type, "web_search") {
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The search query to look up on the web",
})
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "web_search",
Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"query"},
Properties: props,
},
},
}, true, nil
}
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
@@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
Description: t.Description,
Parameters: params,
},
}, nil
}, false, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
@@ -899,3 +1098,113 @@ func countContentBlock(block any) int {
return total
}
// OllamaWebSearchRequest represents a request to the Ollama web search API
type OllamaWebSearchRequest struct {
Query string `json:"query"`
MaxResults int `json:"max_results,omitempty"`
}
// OllamaWebSearchResult represents a single search result from Ollama API
type OllamaWebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
// OllamaWebSearchResponse represents the response from the Ollama web search API
type OllamaWebSearchResponse struct {
Results []OllamaWebSearchResult `json:"results"`
}
var WebSearchEndpoint = "https://ollama.com/api/web_search"
func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
if internalcloud.Disabled() {
logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
}
if maxResults <= 0 {
maxResults = 5
}
if maxResults > 10 {
maxResults = 10
}
reqBody := OllamaWebSearchRequest{
Query: query,
MaxResults: maxResults,
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal web search request: %w", err)
}
searchURL, err := url.Parse(WebSearchEndpoint)
if err != nil {
return nil, fmt.Errorf("failed to parse web search URL: %w", err)
}
logutil.TraceContext(ctx, "anthropic: web search request",
"query", TraceTruncateString(query),
"max_results", maxResults,
"url", searchURL.String(),
)
q := searchURL.Query()
q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
searchURL.RawQuery = q.Encode()
signature := ""
if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
signature, err = auth.Sign(ctx, []byte(challenge))
if err != nil {
return nil, fmt.Errorf("failed to sign web search request: %w", err)
}
}
logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create web search request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("web search request failed: %w", err)
}
defer resp.Body.Close()
logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
}
var searchResp OllamaWebSearchResponse
if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
return nil, fmt.Errorf("failed to decode web search response: %w", err)
}
logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))
return &searchResp, nil
}
func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
var results []WebSearchResult
for _, r := range ollamaResults.Results {
results = append(results, WebSearchResult{
Type: "web_search_result",
URL: r.URL,
Title: r.Title,
})
}
return results
}

View File

@@ -3,6 +3,7 @@ package anthropic
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
}
}
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Type: "web_search_20250305",
Name: "web_search",
},
{
Type: "custom",
Name: "web_search",
Description: "User-defined web search that should be dropped",
InputSchema: json.RawMessage(`{"type":"invalid"}`),
},
{
Type: "custom",
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 2 {
t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
}
if result.Tools[0].Function.Name != "web_search" {
t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
}
if result.Tools[1].Function.Name != "get_weather" {
t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
}
}
func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Type: "custom",
Name: "web_search",
Description: "User-defined web search",
InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
}
if result.Tools[0].Function.Name != "web_search" {
t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
}
if result.Tools[0].Function.Description != "User-defined web search" {
t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}
// Web Search Tests
func TestConvertTool_WebSearch(t *testing.T) {
tool := Tool{
Type: "web_search_20250305",
Name: "web_search",
MaxUses: 5,
}
result, isServerTool, err := convertTool(tool)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !isServerTool {
t.Error("expected isServerTool to be true for web_search tool")
}
if result.Type != "function" {
t.Errorf("expected type 'function', got %q", result.Type)
}
if result.Function.Name != "web_search" {
t.Errorf("expected name 'web_search', got %q", result.Function.Name)
}
if result.Function.Description == "" {
t.Error("expected non-empty description for web_search tool")
}
// Check that query parameter is defined
if result.Function.Parameters.Properties == nil {
t.Fatal("expected properties to be defined")
}
queryProp, ok := result.Function.Parameters.Properties.Get("query")
if !ok {
t.Error("expected 'query' property to be defined")
}
if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
}
}
func TestConvertTool_RegularTool(t *testing.T) {
tool := Tool{
Type: "custom",
Name: "get_weather",
Description: "Get the weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
}
result, isServerTool, err := convertTool(tool)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if isServerTool {
t.Error("expected isServerTool to be false for regular tool")
}
if result.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
}
}
func TestConvertMessage_ServerToolUse(t *testing.T) {
msg := MessageParam{
Role: "assistant",
Content: []any{
map[string]any{
"type": "server_tool_use",
"id": "srvtoolu_123",
"name": "web_search",
"input": map[string]any{"query": "test query"},
},
},
}
messages, err := convertMessage(msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
if len(messages[0].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
}
tc := messages[0].ToolCalls[0]
if tc.ID != "srvtoolu_123" {
t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
}
if tc.Function.Name != "web_search" {
t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
}
}
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
msg := MessageParam{
Role: "user",
Content: []any{
map[string]any{
"type": "web_search_tool_result",
"tool_use_id": "srvtoolu_123",
"content": []any{
map[string]any{
"type": "web_search_result",
"title": "Test Result",
"url": "https://example.com",
},
},
},
},
}
messages, err := convertMessage(msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should have a tool result message
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
if messages[0].Role != "tool" {
t.Errorf("expected role 'tool', got %q", messages[0].Role)
}
if messages[0].ToolCallID != "srvtoolu_123" {
t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
}
if messages[0].Content == "" {
t.Error("expected non-empty content from web search results")
}
}
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{
Role: "user",
Content: []any{
map[string]any{
"type": "web_search_tool_result",
"tool_use_id": "srvtoolu_empty",
"content": []any{},
},
},
}
messages, err := convertMessage(msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
if messages[0].Role != "tool" {
t.Fatalf("expected role tool, got %q", messages[0].Role)
}
if messages[0].ToolCallID != "srvtoolu_empty" {
t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
}
if messages[0].Content != "" {
t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
}
}
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{
Role: "user",
Content: []any{
map[string]any{
"type": "web_search_tool_result",
"tool_use_id": "srvtoolu_error",
"content": map[string]any{
"type": "web_search_tool_result_error",
"error_code": "max_uses_exceeded",
},
},
},
}
messages, err := convertMessage(msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
if messages[0].Role != "tool" {
t.Fatalf("expected role tool, got %q", messages[0].Role)
}
if messages[0].ToolCallID != "srvtoolu_error" {
t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
}
if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
}
}
func TestConvertOllamaToAnthropicResults(t *testing.T) {
ollamaResp := &OllamaWebSearchResponse{
Results: []OllamaWebSearchResult{
{
Title: "Test Title",
URL: "https://example.com",
Content: "Test content",
},
{
Title: "Another Result",
URL: "https://example.org",
Content: "More content",
},
},
}
results := ConvertOllamaToAnthropicResults(ollamaResp)
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
if results[0].Type != "web_search_result" {
t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
}
if results[0].Title != "Test Title" {
t.Errorf("expected title 'Test Title', got %q", results[0].Title)
}
if results[0].URL != "https://example.com" {
t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
}
}
func TestWebSearchTypes(t *testing.T) {
// Test that WebSearchResult serializes correctly
result := WebSearchResult{
Type: "web_search_result",
URL: "https://example.com",
Title: "Test",
EncryptedContent: "abc123",
PageAge: "2025-01-01",
}
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("failed to marshal WebSearchResult: %v", err)
}
var unmarshaled WebSearchResult
if err := json.Unmarshal(data, &unmarshaled); err != nil {
t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
}
if unmarshaled.Type != result.Type {
t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
}
// Test WebSearchToolResultError
errResult := WebSearchToolResultError{
Type: "web_search_tool_result_error",
ErrorCode: "max_uses_exceeded",
}
data, err = json.Marshal(errResult)
if err != nil {
t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
}
var unmarshaledErr WebSearchToolResultError
if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
}
if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
}
}
func TestCitation(t *testing.T) {
citation := Citation{
Type: "web_search_result_location",
URL: "https://example.com",
Title: "Example",
EncryptedIndex: "enc123",
CitedText: "Some cited text...",
}
data, err := json.Marshal(citation)
if err != nil {
t.Fatalf("failed to marshal Citation: %v", err)
}
var unmarshaled Citation
if err := json.Unmarshal(data, &unmarshaled); err != nil {
t.Fatalf("failed to unmarshal Citation: %v", err)
}
if unmarshaled.Type != "web_search_result_location" {
t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
}
if unmarshaled.CitedText != "Some cited text..." {
t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
}
}

352
anthropic/trace.go Normal file
View File

@@ -0,0 +1,352 @@
package anthropic
import (
"encoding/json"
"fmt"
"sort"
"github.com/ollama/ollama/api"
)
// Trace truncation limits.
const (
TraceMaxStringRunes = 240
TraceMaxSliceItems = 8
TraceMaxMapEntries = 16
TraceMaxDepth = 4
)
// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
// omitted characters when truncated.
func TraceTruncateString(s string) string {
if len(s) == 0 {
return s
}
runes := []rune(s)
if len(runes) <= TraceMaxStringRunes {
return s
}
return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
}
// TraceJSON round-trips v through JSON and returns a compacted representation.
func TraceJSON(v any) any {
if v == nil {
return nil
}
data, err := json.Marshal(v)
if err != nil {
return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
}
var out any
if err := json.Unmarshal(data, &out); err != nil {
return TraceTruncateString(string(data))
}
return TraceCompactValue(out, 0)
}
// TraceCompactValue recursively truncates strings, slices, and maps for trace
// output. depth tracks recursion to enforce TraceMaxDepth.
func TraceCompactValue(v any, depth int) any {
if v == nil {
return nil
}
if depth >= TraceMaxDepth {
switch t := v.(type) {
case string:
return TraceTruncateString(t)
case []any:
return fmt.Sprintf("<array len=%d>", len(t))
case map[string]any:
return fmt.Sprintf("<object keys=%d>", len(t))
default:
return fmt.Sprintf("<%T>", v)
}
}
switch t := v.(type) {
case string:
return TraceTruncateString(t)
case []any:
limit := min(len(t), TraceMaxSliceItems)
out := make([]any, 0, limit+1)
for i := range limit {
out = append(out, TraceCompactValue(t[i], depth+1))
}
if len(t) > limit {
out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
}
return out
case map[string]any:
keys := make([]string, 0, len(t))
for k := range t {
keys = append(keys, k)
}
sort.Strings(keys)
limit := min(len(keys), TraceMaxMapEntries)
out := make(map[string]any, limit+1)
for i := range limit {
out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
}
if len(keys) > limit {
out["__truncated_keys"] = len(keys) - limit
}
return out
default:
return t
}
}
// ---------------------------------------------------------------------------
// Anthropic request/response tracing
// ---------------------------------------------------------------------------
// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
func TraceMessagesRequest(r MessagesRequest) map[string]any {
return map[string]any{
"model": r.Model,
"max_tokens": r.MaxTokens,
"messages": traceMessageParams(r.Messages),
"system": traceAnthropicContent(r.System),
"stream": r.Stream,
"tools": traceTools(r.Tools),
"tool_choice": TraceJSON(r.ToolChoice),
"thinking": TraceJSON(r.Thinking),
"stop_sequences": r.StopSequences,
"temperature": ptrVal(r.Temperature),
"top_p": ptrVal(r.TopP),
"top_k": ptrVal(r.TopK),
}
}
// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
func TraceMessagesResponse(r MessagesResponse) map[string]any {
return map[string]any{
"id": r.ID,
"model": r.Model,
"content": TraceJSON(r.Content),
"stop_reason": r.StopReason,
"usage": r.Usage,
}
}
func traceMessageParams(msgs []MessageParam) []map[string]any {
out := make([]map[string]any, 0, len(msgs))
for _, m := range msgs {
out = append(out, map[string]any{
"role": m.Role,
"content": traceAnthropicContent(m.Content),
})
}
return out
}
func traceAnthropicContent(content any) any {
switch c := content.(type) {
case nil:
return nil
case string:
return TraceTruncateString(c)
case []any:
blocks := make([]any, 0, len(c))
for _, block := range c {
blockMap, ok := block.(map[string]any)
if !ok {
blocks = append(blocks, TraceCompactValue(block, 0))
continue
}
blocks = append(blocks, traceAnthropicBlock(blockMap))
}
return blocks
default:
return TraceJSON(c)
}
}
func traceAnthropicBlock(block map[string]any) map[string]any {
blockType, _ := block["type"].(string)
out := map[string]any{"type": blockType}
switch blockType {
case "text":
if text, ok := block["text"].(string); ok {
out["text"] = TraceTruncateString(text)
} else {
out["text"] = TraceCompactValue(block["text"], 0)
}
case "thinking":
if thinking, ok := block["thinking"].(string); ok {
out["thinking"] = TraceTruncateString(thinking)
} else {
out["thinking"] = TraceCompactValue(block["thinking"], 0)
}
case "tool_use", "server_tool_use":
out["id"] = block["id"]
out["name"] = block["name"]
out["input"] = TraceCompactValue(block["input"], 0)
case "tool_result", "web_search_tool_result":
out["tool_use_id"] = block["tool_use_id"]
out["content"] = TraceCompactValue(block["content"], 0)
case "image":
if source, ok := block["source"].(map[string]any); ok {
out["source"] = map[string]any{
"type": source["type"],
"media_type": source["media_type"],
"url": source["url"],
"data_len": len(fmt.Sprint(source["data"])),
}
}
default:
out["block"] = TraceCompactValue(block, 0)
}
return out
}
func traceTools(tools []Tool) []map[string]any {
out := make([]map[string]any, 0, len(tools))
for _, t := range tools {
out = append(out, TraceTool(t))
}
return out
}
// TraceTool returns a compact trace representation of an Anthropic Tool.
func TraceTool(t Tool) map[string]any {
return map[string]any{
"type": t.Type,
"name": t.Name,
"description": TraceTruncateString(t.Description),
"input_schema": TraceJSON(t.InputSchema),
"max_uses": t.MaxUses,
}
}
// ContentBlockTypes returns the type strings from content (when it's []any blocks).
func ContentBlockTypes(content any) []string {
blocks, ok := content.([]any)
if !ok {
return nil
}
types := make([]string, 0, len(blocks))
for _, block := range blocks {
blockMap, ok := block.(map[string]any)
if !ok {
types = append(types, fmt.Sprintf("%T", block))
continue
}
t, _ := blockMap["type"].(string)
types = append(types, t)
}
return types
}
func ptrVal[T any](v *T) any {
if v == nil {
return nil
}
return *v
}
// ---------------------------------------------------------------------------
// Ollama api.* tracing (shared between anthropic and middleware packages)
// ---------------------------------------------------------------------------
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
func TraceChatRequest(req *api.ChatRequest) map[string]any {
if req == nil {
return nil
}
stream := false
if req.Stream != nil {
stream = *req.Stream
}
return map[string]any{
"model": req.Model,
"messages": TraceAPIMessages(req.Messages),
"tools": TraceAPITools(req.Tools),
"stream": stream,
"options": req.Options,
"think": TraceJSON(req.Think),
}
}
// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
func TraceChatResponse(resp api.ChatResponse) map[string]any {
return map[string]any{
"model": resp.Model,
"done": resp.Done,
"done_reason": resp.DoneReason,
"message": TraceAPIMessage(resp.Message),
"metrics": TraceJSON(resp.Metrics),
}
}
// TraceAPIMessages returns compact trace representations for a slice of api.Message.
func TraceAPIMessages(msgs []api.Message) []map[string]any {
out := make([]map[string]any, 0, len(msgs))
for _, m := range msgs {
out = append(out, TraceAPIMessage(m))
}
return out
}
// TraceAPIMessage returns a compact trace representation of a single api.Message.
func TraceAPIMessage(m api.Message) map[string]any {
return map[string]any{
"role": m.Role,
"content": TraceTruncateString(m.Content),
"thinking": TraceTruncateString(m.Thinking),
"images": traceImageSizes(m.Images),
"tool_calls": traceToolCalls(m.ToolCalls),
"tool_name": m.ToolName,
"tool_call_id": m.ToolCallID,
}
}
func traceImageSizes(images []api.ImageData) []int {
if len(images) == 0 {
return nil
}
sizes := make([]int, 0, len(images))
for _, img := range images {
sizes = append(sizes, len(img))
}
return sizes
}
// TraceAPITools returns compact trace representations for a slice of api.Tool.
func TraceAPITools(tools api.Tools) []map[string]any {
out := make([]map[string]any, 0, len(tools))
for _, t := range tools {
out = append(out, TraceAPITool(t))
}
return out
}
// TraceAPITool returns a compact trace representation of a single api.Tool.
func TraceAPITool(t api.Tool) map[string]any {
return map[string]any{
"type": t.Type,
"name": t.Function.Name,
"description": TraceTruncateString(t.Function.Description),
"parameters": TraceJSON(t.Function.Parameters),
}
}
// TraceToolCall returns a compact trace representation of an api.ToolCall.
func TraceToolCall(tc api.ToolCall) map[string]any {
return map[string]any{
"id": tc.ID,
"name": tc.Function.Name,
"args": TraceJSON(tc.Function.Arguments),
}
}
func traceToolCalls(tcs []api.ToolCall) []map[string]any {
if len(tcs) == 0 {
return nil
}
out := make([]map[string]any, 0, len(tcs))
for _, tc := range tcs {
out = append(out, TraceToolCall(tc))
}
return out
}

View File

@@ -449,6 +449,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
// CloudStatusExperimental returns whether cloud features are disabled on the server.
func (c *Client) CloudStatusExperimental(ctx context.Context) (*StatusResponse, error) {
var status StatusResponse
if err := c.do(ctx, http.MethodGet, "/api/status", nil, &status); err != nil {
return nil, err
}
return &status, nil
}
// Signout will signout a client for a local ollama server.
func (c *Client) Signout(ctx context.Context) error {
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)

View File

@@ -834,6 +834,16 @@ type TokenResponse struct {
Token string `json:"token"`
}
type CloudStatus struct {
Disabled bool `json:"disabled"`
Source string `json:"source"`
}
// StatusResponse is the response from [Client.CloudStatusExperimental].
type StatusResponse struct {
Cloud CloudStatus `json:"cloud"`
}
// GenerateResponse is the response passed into [GenerateResponseFunc].
type GenerateResponse struct {
// Model is the model name that generated the response.

View File

@@ -205,6 +205,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
return nil, err
}
cloudDisabled, err := s.store.CloudDisabled()
if err != nil {
return nil, err
}
cmd := commandContext(ctx, s.bin, "serve")
cmd.Stdout, cmd.Stderr = s.log, s.log
@@ -230,6 +235,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if settings.ContextLength > 0 {
env["OLLAMA_CONTEXT_LENGTH"] = strconv.Itoa(settings.ContextLength)
}
if cloudDisabled {
env["OLLAMA_NO_CLOUD"] = "1"
} else {
env["OLLAMA_NO_CLOUD"] = "0"
}
cmd.Env = []string{}
for k, v := range env {
cmd.Env = append(cmd.Env, k+"="+v)

View File

@@ -111,7 +111,7 @@ func TestServerCmd(t *testing.T) {
for _, want := range tt.want {
found := false
for _, env := range cmd.Env {
if strings.Contains(env, want) {
if strings.HasPrefix(env, want) {
found = true
break
}
@@ -123,7 +123,7 @@ func TestServerCmd(t *testing.T) {
for _, dont := range tt.dont {
for _, env := range cmd.Env {
if strings.Contains(env, dont) {
if strings.HasPrefix(env, dont) {
t.Errorf("unexpected environment variable: %s", env)
}
}
@@ -136,6 +136,75 @@ func TestServerCmd(t *testing.T) {
}
}
func TestServerCmdCloudSettingEnv(t *testing.T) {
tests := []struct {
name string
envValue string
configContent string
want string
}{
{
name: "default cloud enabled",
want: "OLLAMA_NO_CLOUD=0",
},
{
name: "env disables cloud",
envValue: "1",
want: "OLLAMA_NO_CLOUD=1",
},
{
name: "config disables cloud",
configContent: `{"disable_ollama_cloud": true}`,
want: "OLLAMA_NO_CLOUD=1",
},
{
name: "invalid env disables cloud",
envValue: "invalid",
want: "OLLAMA_NO_CLOUD=1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
t.Setenv("USERPROFILE", tmpHome)
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
if tt.configContent != "" {
configDir := filepath.Join(tmpHome, ".ollama")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatalf("mkdir config dir: %v", err)
}
configPath := filepath.Join(configDir, "server.json")
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
t.Fatalf("write config: %v", err)
}
}
st := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
defer st.Close()
s := &Server{store: st}
cmd, err := s.cmd(t.Context())
if err != nil {
t.Fatalf("s.cmd() error = %v", err)
}
found := false
for _, env := range cmd.Env {
if env == tt.want {
found = true
break
}
}
if !found {
t.Fatalf("expected environment variable %q in command env", tt.want)
}
})
}
}
func TestGetInferenceComputer(t *testing.T) {
tests := []struct {
name string

128
app/store/cloud_config.go Normal file
View File

@@ -0,0 +1,128 @@
//go:build windows || darwin
package store
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/envconfig"
)
const serverConfigFilename = "server.json"
type serverConfig struct {
DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
}
// CloudDisabled returns whether cloud features should be disabled.
// The source of truth is: OLLAMA_NO_CLOUD OR ~/.ollama/server.json:disable_ollama_cloud.
func (s *Store) CloudDisabled() (bool, error) {
disabled, _, err := s.CloudStatus()
return disabled, err
}
// CloudStatus returns whether cloud is disabled and the source of that decision.
// Source is one of: "none", "env", "config", "both".
func (s *Store) CloudStatus() (bool, string, error) {
if err := s.ensureDB(); err != nil {
return false, "", err
}
configDisabled, err := readServerConfigCloudDisabled()
if err != nil {
return false, "", err
}
envDisabled := envconfig.NoCloudEnv()
return envDisabled || configDisabled, cloudStatusSource(envDisabled, configDisabled), nil
}
// SetCloudEnabled writes the cloud setting to ~/.ollama/server.json.
func (s *Store) SetCloudEnabled(enabled bool) error {
if err := s.ensureDB(); err != nil {
return err
}
return setCloudEnabled(enabled)
}
func setCloudEnabled(enabled bool) error {
configPath, err := serverConfigPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return fmt.Errorf("create server config directory: %w", err)
}
configMap := map[string]any{}
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &configMap); err != nil {
// If the existing file is invalid JSON, overwrite with a fresh object.
configMap = map[string]any{}
}
} else if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("read server config: %w", err)
}
configMap["disable_ollama_cloud"] = !enabled
data, err := json.MarshalIndent(configMap, "", " ")
if err != nil {
return fmt.Errorf("marshal server config: %w", err)
}
data = append(data, '\n')
if err := os.WriteFile(configPath, data, 0o644); err != nil {
return fmt.Errorf("write server config: %w", err)
}
return nil
}
func readServerConfigCloudDisabled() (bool, error) {
configPath, err := serverConfigPath()
if err != nil {
return false, err
}
data, err := os.ReadFile(configPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("read server config: %w", err)
}
var cfg serverConfig
// Invalid or unexpected JSON should not block startup; treat as default.
if json.Unmarshal(data, &cfg) == nil {
return cfg.DisableOllamaCloud, nil
}
return false, nil
}
func serverConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolve home directory: %w", err)
}
return filepath.Join(home, ".ollama", serverConfigFilename), nil
}
func cloudStatusSource(envDisabled bool, configDisabled bool) string {
switch {
case envDisabled && configDisabled:
return "both"
case envDisabled:
return "env"
case configDisabled:
return "config"
default:
return "none"
}
}

View File

@@ -0,0 +1,130 @@
//go:build windows || darwin
package store
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestCloudDisabled(t *testing.T) {
tests := []struct {
name string
envValue string
configContent string
wantDisabled bool
wantSource string
}{
{
name: "default enabled",
wantDisabled: false,
wantSource: "none",
},
{
name: "env disables cloud",
envValue: "1",
wantDisabled: true,
wantSource: "env",
},
{
name: "config disables cloud",
configContent: `{"disable_ollama_cloud": true}`,
wantDisabled: true,
wantSource: "config",
},
{
name: "env and config",
envValue: "1",
configContent: `{"disable_ollama_cloud": false}`,
wantDisabled: true,
wantSource: "env",
},
{
name: "invalid config is ignored",
configContent: `{bad`,
wantDisabled: false,
wantSource: "none",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpHome := t.TempDir()
setTestHome(t, tmpHome)
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
if tt.configContent != "" {
configDir := filepath.Join(tmpHome, ".ollama")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatalf("mkdir config dir: %v", err)
}
configPath := filepath.Join(configDir, serverConfigFilename)
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
t.Fatalf("write config: %v", err)
}
}
s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
defer s.Close()
disabled, err := s.CloudDisabled()
if err != nil {
t.Fatalf("CloudDisabled() error = %v", err)
}
if disabled != tt.wantDisabled {
t.Fatalf("CloudDisabled() = %v, want %v", disabled, tt.wantDisabled)
}
statusDisabled, source, err := s.CloudStatus()
if err != nil {
t.Fatalf("CloudStatus() error = %v", err)
}
if statusDisabled != tt.wantDisabled {
t.Fatalf("CloudStatus() disabled = %v, want %v", statusDisabled, tt.wantDisabled)
}
if source != tt.wantSource {
t.Fatalf("CloudStatus() source = %v, want %v", source, tt.wantSource)
}
})
}
}
func TestSetCloudEnabled(t *testing.T) {
tmpHome := t.TempDir()
setTestHome(t, tmpHome)
configDir := filepath.Join(tmpHome, ".ollama")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatalf("mkdir config dir: %v", err)
}
configPath := filepath.Join(configDir, serverConfigFilename)
if err := os.WriteFile(configPath, []byte(`{"another_key":"value","disable_ollama_cloud":true}`), 0o644); err != nil {
t.Fatalf("seed config: %v", err)
}
s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
defer s.Close()
if err := s.SetCloudEnabled(true); err != nil {
t.Fatalf("SetCloudEnabled(true) error = %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("read config: %v", err)
}
var got map[string]any
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal config: %v", err)
}
if got["disable_ollama_cloud"] != false {
t.Fatalf("disable_ollama_cloud = %v, want false", got["disable_ollama_cloud"])
}
if got["another_key"] != "value" {
t.Fatalf("another_key = %v, want value", got["another_key"])
}
}

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12
const currentSchemaVersion = 13
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -84,6 +84,7 @@ func (db *database) init() error {
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '',
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
remote TEXT NOT NULL DEFAULT '', -- deprecated
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err)
}
version = 12
case 12:
// add cloud_setting_migrated column to settings table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
return nil
}
// migrateV12ToV13 adds cloud_setting_migrated to settings.
func (db *database) migrateV12ToV13() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add cloud_setting_migrated column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -1108,9 +1130,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1121,14 +1143,40 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}
return nil
}
func (db *database) isCloudSettingMigrated() (bool, error) {
var migrated bool
err := db.conn.QueryRow("SELECT cloud_setting_migrated FROM settings").Scan(&migrated)
if err != nil {
return false, fmt.Errorf("get cloud setting migration status: %w", err)
}
return migrated, nil
}
func (db *database) setCloudSettingMigrated(migrated bool) error {
_, err := db.conn.Exec("UPDATE settings SET cloud_setting_migrated = ?", migrated)
if err != nil {
return fmt.Errorf("set cloud setting migration status: %w", err)
}
return nil
}
func (db *database) getAirplaneMode() (bool, error) {
var airplaneMode bool
err := db.conn.QueryRow("SELECT airplane_mode FROM settings").Scan(&airplaneMode)
if err != nil {
return false, fmt.Errorf("get airplane_mode: %w", err)
}
return airplaneMode, nil
}
func (db *database) getWindowSize() (int, int, error) {
var width, height int
err := db.conn.QueryRow("SELECT window_width, window_height FROM settings").Scan(&width, &height)

View File

@@ -127,6 +127,65 @@ func TestNoConfigToMigrate(t *testing.T) {
}
}
func TestCloudMigrationFromAirplaneMode(t *testing.T) {
tmpHome := t.TempDir()
setTestHome(t, tmpHome)
t.Setenv("OLLAMA_NO_CLOUD", "")
dbPath := filepath.Join(tmpHome, "db.sqlite")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
if _, err := db.conn.Exec("UPDATE settings SET airplane_mode = 1, cloud_setting_migrated = 0"); err != nil {
db.Close()
t.Fatalf("failed to seed airplane migration state: %v", err)
}
db.Close()
s := Store{DBPath: dbPath}
defer s.Close()
// Trigger DB initialization + one-time cloud migration.
if _, err := s.ID(); err != nil {
t.Fatalf("failed to initialize store: %v", err)
}
disabled, err := s.CloudDisabled()
if err != nil {
t.Fatalf("CloudDisabled() error: %v", err)
}
if !disabled {
t.Fatal("expected cloud to be disabled after migrating airplane_mode=true")
}
configPath := filepath.Join(tmpHome, ".ollama", serverConfigFilename)
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("failed to read migrated server config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("failed to parse migrated server config: %v", err)
}
if cfg["disable_ollama_cloud"] != true {
t.Fatalf("disable_ollama_cloud = %v, want true", cfg["disable_ollama_cloud"])
}
var airplaneMode, migrated bool
if err := s.db.conn.QueryRow("SELECT airplane_mode, cloud_setting_migrated FROM settings").Scan(&airplaneMode, &migrated); err != nil {
t.Fatalf("failed to read migration flags from DB: %v", err)
}
if !airplaneMode {
t.Fatal("expected legacy airplane_mode value to remain unchanged")
}
if !migrated {
t.Fatal("expected cloud_setting_migrated to be true")
}
}
const (
v1Schema = `
CREATE TABLE IF NOT EXISTS settings (

View File

@@ -149,9 +149,6 @@ type Settings struct {
// ContextLength specifies the context length for the ollama server (using OLLAMA_CONTEXT_LENGTH)
ContextLength int
// AirplaneMode when true, turns off Ollama Turbo features and only uses local models
AirplaneMode bool
// TurboEnabled indicates if Ollama Turbo features are enabled
TurboEnabled bool
@@ -259,6 +256,40 @@ func (s *Store) ensureDB() error {
}
}
// Run one-time migration from legacy airplane_mode behavior.
if err := s.migrateCloudSetting(database); err != nil {
return fmt.Errorf("migrate cloud setting: %w", err)
}
return nil
}
// migrateCloudSetting migrates legacy airplane_mode into server.json exactly once.
// After this, cloud state is sourced from server.json OR OLLAMA_NO_CLOUD.
func (s *Store) migrateCloudSetting(database *database) error {
migrated, err := database.isCloudSettingMigrated()
if err != nil {
return err
}
if migrated {
return nil
}
airplaneMode, err := database.getAirplaneMode()
if err != nil {
return err
}
if airplaneMode {
if err := setCloudEnabled(false); err != nil {
return fmt.Errorf("migrate airplane_mode to cloud disabled: %w", err)
}
}
if err := database.setCloudSettingMigrated(true); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,11 @@
//go:build windows || darwin
package store
import "testing"
func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
}

35
app/tools/cloud_policy.go Normal file
View File

@@ -0,0 +1,35 @@
//go:build windows || darwin
package tools
import (
"context"
"errors"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
// ensureCloudEnabledForTool checks cloud policy from the connected Ollama server.
// If policy cannot be determined, this fails closed and blocks the operation.
func ensureCloudEnabledForTool(ctx context.Context, operation string) error {
// Reuse shared message formatting; policy evaluation is still done via
// the connected server's /api/status endpoint below.
disabledMessage := internalcloud.DisabledError(operation)
client, err := api.ClientFromEnvironment()
if err != nil {
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
}
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
}
if status.Cloud.Disabled {
return errors.New(disabledMessage)
}
return nil
}

View File

@@ -0,0 +1,73 @@
//go:build windows || darwin
package tools
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestEnsureCloudEnabledForTool(t *testing.T) {
const op = "web search is unavailable"
const disabledPrefix = "ollama cloud is disabled: web search is unavailable"
t.Run("enabled allows tool execution", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/status" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cloud":{"disabled":false,"source":"none"}}`))
}))
t.Cleanup(ts.Close)
t.Setenv("OLLAMA_HOST", ts.URL)
if err := ensureCloudEnabledForTool(context.Background(), op); err != nil {
t.Fatalf("expected nil error, got %v", err)
}
})
t.Run("disabled blocks tool execution", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/status" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cloud":{"disabled":true,"source":"config"}}`))
}))
t.Cleanup(ts.Close)
t.Setenv("OLLAMA_HOST", ts.URL)
err := ensureCloudEnabledForTool(context.Background(), op)
if err == nil {
t.Fatal("expected error, got nil")
}
if got := err.Error(); got != disabledPrefix {
t.Fatalf("unexpected error: %q", got)
}
})
t.Run("status unavailable fails closed", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
t.Cleanup(ts.Close)
t.Setenv("OLLAMA_HOST", ts.URL)
err := ensureCloudEnabledForTool(context.Background(), op)
if err == nil {
t.Fatal("expected error, got nil")
}
if got := err.Error(); !strings.Contains(got, disabledPrefix) {
t.Fatalf("expected disabled prefix, got %q", got)
}
if got := err.Error(); !strings.Contains(got, "unable to verify server cloud policy") {
t.Fatalf("expected verification failure detail, got %q", got)
}
})
}

View File

@@ -77,6 +77,10 @@ func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, strin
}
func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
if err := ensureCloudEnabledForTool(ctx, "web fetch is unavailable"); err != nil {
return nil, err
}
reqBody := FetchRequest{URL: targetURL}
jsonBody, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -93,6 +93,10 @@ func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, stri
}
func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
if err := ensureCloudEnabledForTool(ctx, "web search is unavailable"); err != nil {
return nil, err
}
reqBody := SearchRequest{Query: query, MaxResults: maxResults}
jsonBody, err := json.Marshal(reqBody)

View File

@@ -406,7 +406,6 @@ export class Settings {
Tools: boolean;
WorkingDir: string;
ContextLength: number;
AirplaneMode: boolean;
TurboEnabled: boolean;
WebSearchEnabled: boolean;
ThinkEnabled: boolean;
@@ -424,7 +423,6 @@ export class Settings {
this.Tools = source["Tools"];
this.WorkingDir = source["WorkingDir"];
this.ContextLength = source["ContextLength"];
this.AirplaneMode = source["AirplaneMode"];
this.TurboEnabled = source["TurboEnabled"];
this.WebSearchEnabled = source["WebSearchEnabled"];
this.ThinkEnabled = source["ThinkEnabled"];

View File

@@ -27,6 +27,12 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
export type CloudStatusSource = "env" | "config" | "both" | "none";
export interface CloudStatusResponse {
disabled: boolean;
source: CloudStatusSource;
}
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -285,6 +291,28 @@ export async function updateSettings(settings: Settings): Promise<{
};
}
export async function updateCloudSetting(
enabled: boolean,
): Promise<CloudStatusResponse> {
const response = await fetch(`${API_BASE}/api/v1/cloud`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ enabled }),
});
if (!response.ok) {
const error = await response.text();
throw new Error(error || "Failed to update cloud setting");
}
const data = await response.json();
return {
disabled: Boolean(data.disabled),
source: (data.source as CloudStatusSource) || "none",
};
}
export async function renameChat(chatId: string, title: string): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/rename`, {
method: "PUT",
@@ -414,3 +442,16 @@ export async function fetchHealth(): Promise<boolean> {
return false;
}
}
export async function getCloudStatus(): Promise<CloudStatusResponse | null> {
const response = await fetch(`${API_BASE}/api/v1/cloud`);
if (!response.ok) {
throw new Error(`Failed to fetch cloud status: ${response.status}`);
}
const data = await response.json();
return {
disabled: Boolean(data.disabled),
source: (data.source as CloudStatusSource) || "none",
};
}

View File

@@ -22,6 +22,7 @@ import { useUser } from "@/hooks/useUser";
import { DisplayLogin } from "@/components/DisplayLogin";
import { ErrorEvent, Message } from "@/gotypes";
import { useSettings } from "@/hooks/useSettings";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { ThinkButton } from "./ThinkButton";
import { ErrorMessage } from "./ErrorMessage";
import { processFiles } from "@/utils/fileValidation";
@@ -141,12 +142,12 @@ function ChatForm({
const {
settings: {
webSearchEnabled,
airplaneMode,
thinkEnabled,
thinkLevel: settingsThinkLevel,
},
setSettings,
} = useSettings();
const { cloudDisabled } = useCloudStatus();
// current supported models for web search
const modelLower = selectedModel?.model.toLowerCase() || "";
@@ -180,6 +181,12 @@ function ChatForm({
setSettings,
]);
useEffect(() => {
if (cloudDisabled && webSearchEnabled) {
setSettings({ WebSearchEnabled: false });
}
}, [cloudDisabled, webSearchEnabled, setSettings]);
const removeFile = (index: number) => {
setMessage((prev) => ({
...prev,
@@ -234,19 +241,19 @@ function ChatForm({
// Determine if login banner should be shown
const shouldShowLoginBanner =
!cloudDisabled &&
!isLoadingUser &&
!isAuthenticated &&
((webSearchEnabled && supportsWebSearch) ||
(selectedModel?.isCloud() && !airplaneMode));
((webSearchEnabled && supportsWebSearch) || selectedModel?.isCloud());
// Determine which feature to highlight in the banner
const getActiveFeatureForBanner = () => {
if (cloudDisabled) return null;
if (!isAuthenticated) {
if (loginPromptFeature) return loginPromptFeature;
if (webSearchEnabled && selectedModel?.isCloud() && !airplaneMode)
return "webSearch";
if (webSearchEnabled && selectedModel?.isCloud()) return "webSearch";
if (webSearchEnabled) return "webSearch";
if (selectedModel?.isCloud() && !airplaneMode) return "turbo";
if (selectedModel?.isCloud()) return "turbo";
}
return null;
};
@@ -269,11 +276,12 @@ function ChatForm({
useEffect(() => {
if (
isAuthenticated ||
(!webSearchEnabled && !!selectedModel?.isCloud() && !airplaneMode)
cloudDisabled ||
(!webSearchEnabled && !!selectedModel?.isCloud())
) {
setLoginPromptFeature(null);
}
}, [isAuthenticated, webSearchEnabled, selectedModel, airplaneMode]);
}, [isAuthenticated, webSearchEnabled, selectedModel, cloudDisabled]);
// When entering edit mode, populate the composition with existing data
useEffect(() => {
@@ -465,6 +473,10 @@ function ChatForm({
const handleSubmit = async () => {
if (!message.content.trim() || isStreaming || isDownloading) return;
if (cloudDisabled && selectedModel?.isCloud()) {
return;
}
// Check if cloud mode is enabled but user is not authenticated
if (shouldShowLoginBanner) {
return;
@@ -478,7 +490,8 @@ function ChatForm({
}),
);
const useWebSearch = supportsWebSearch && webSearchEnabled && !airplaneMode;
const useWebSearch =
supportsWebSearch && webSearchEnabled && !cloudDisabled;
const useThink = modelSupportsThinkingLevels
? thinkLevel
: supportsThinkToggling
@@ -899,7 +912,7 @@ function ChatForm({
)}
<WebSearchButton
ref={webSearchButtonRef}
isVisible={supportsWebSearch && airplaneMode === false}
isVisible={supportsWebSearch && cloudDisabled === false}
isActive={webSearchEnabled}
onToggle={() => {
if (!webSearchEnabled && !isAuthenticated) {
@@ -940,6 +953,7 @@ function ChatForm({
!isDownloading &&
(!message.content.trim() ||
shouldShowLoginBanner ||
(cloudDisabled && selectedModel?.isCloud()) ||
message.fileErrors.length > 0)
}
className={`flex items-center justify-center h-9 w-9 rounded-full disabled:cursor-default cursor-pointer bg-black text-white dark:bg-white dark:text-black disabled:opacity-10 focus:outline-none focus:ring-2 focus:ring-blue-500`}

View File

@@ -8,7 +8,7 @@ import {
} from "react";
import { Model } from "@/gotypes";
import { useSelectedModel } from "@/hooks/useSelectedModel";
import { useSettings } from "@/hooks/useSettings";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { useQueryClient } from "@tanstack/react-query";
import { getModelUpstreamInfo } from "@/api";
import { ArrowDownTrayIcon } from "@heroicons/react/24/outline";
@@ -34,7 +34,7 @@ export const ModelPicker = forwardRef<
chatId,
searchQuery,
);
const { settings } = useSettings();
const { cloudDisabled } = useCloudStatus();
const dropdownRef = useRef<HTMLDivElement>(null);
const searchInputRef = useRef<HTMLInputElement>(null);
const queryClient = useQueryClient();
@@ -219,7 +219,7 @@ export const ModelPicker = forwardRef<
models={models}
selectedModel={selectedModel}
onModelSelect={handleModelSelect}
airplaneMode={settings.airplaneMode}
cloudDisabled={cloudDisabled}
isOpen={isOpen}
/>
</div>
@@ -233,13 +233,13 @@ export const ModelList = forwardRef(function ModelList(
models,
selectedModel,
onModelSelect,
airplaneMode,
cloudDisabled,
isOpen,
}: {
models: Model[];
selectedModel: Model | null;
onModelSelect: (model: Model) => void;
airplaneMode: boolean;
cloudDisabled: boolean;
isOpen: boolean;
},
ref,
@@ -348,7 +348,7 @@ export const ModelList = forwardRef(function ModelList(
</svg>
)}
{model.digest === undefined &&
(airplaneMode || !model.isCloud()) && (
(cloudDisabled || !model.isCloud()) && (
<ArrowDownTrayIcon
className="h-4 w-4 text-neutral-500 dark:text-neutral-400"
strokeWidth={1.75}

View File

@@ -11,6 +11,7 @@ import {
FolderIcon,
BoltIcon,
WrenchIcon,
CloudIcon,
XMarkIcon,
CogIcon,
ArrowLeftIcon,
@@ -18,8 +19,14 @@ import {
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api";
import {
getSettings,
type CloudStatusResponse,
updateCloudSetting,
updateSettings,
} from "@/api";
function AnimatedDots() {
return (
@@ -53,6 +60,11 @@ export default function Settings() {
const [connectionError, setConnectionError] = useState<string | null>(null);
const [pollingInterval, setPollingInterval] = useState<number | null>(null);
const navigate = useNavigate();
const {
cloudDisabled,
cloudStatus,
isLoading: cloudStatusLoading,
} = useCloudStatus();
const {
data: settingsData,
@@ -74,6 +86,50 @@ export default function Settings() {
},
});
const updateCloudMutation = useMutation({
mutationFn: (enabled: boolean) => updateCloudSetting(enabled),
onMutate: async (enabled: boolean) => {
await queryClient.cancelQueries({ queryKey: ["cloudStatus"] });
const previous = queryClient.getQueryData<CloudStatusResponse | null>([
"cloudStatus",
]);
const envForcesDisabled =
previous?.source === "env" || previous?.source === "both";
queryClient.setQueryData<CloudStatusResponse | null>(
["cloudStatus"],
previous
? {
...previous,
disabled: !enabled || envForcesDisabled,
}
: {
disabled: !enabled,
source: "config",
},
);
return { previous };
},
onError: (_error, _enabled, context) => {
if (context?.previous !== undefined) {
queryClient.setQueryData(["cloudStatus"], context.previous);
}
},
onSuccess: (status) => {
queryClient.setQueryData<CloudStatusResponse | null>(
["cloudStatus"],
status,
);
queryClient.invalidateQueries({ queryKey: ["models"] });
queryClient.invalidateQueries({ queryKey: ["cloudStatus"] });
setShowSaved(true);
setTimeout(() => setShowSaved(false), 1500);
},
});
useEffect(() => {
refetchUser();
}, []); // eslint-disable-line react-hooks/exhaustive-deps
@@ -149,12 +205,16 @@ export default function Settings() {
Agent: false,
Tools: false,
ContextLength: 4096,
AirplaneMode: false,
});
updateSettingsMutation.mutate(defaultSettings);
}
};
const cloudOverriddenByEnv =
cloudStatus?.source === "env" || cloudStatus?.source === "both";
const cloudToggleDisabled =
cloudStatusLoading || updateCloudMutation.isPending || cloudOverriddenByEnv;
const handleConnectOllamaAccount = async () => {
setConnectionError(null);
@@ -237,7 +297,7 @@ export default function Settings() {
<div className="space-y-4 max-w-2xl mx-auto">
{/* Connect Ollama Account */}
<div className="overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="p-4 border-b border-neutral-200 dark:border-neutral-800">
<div className="p-4">
<Field>
{isLoading ? (
// Loading skeleton, this will only happen if the app started recently
@@ -344,6 +404,34 @@ export default function Settings() {
{/* Local Configuration */}
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="space-y-4 p-4">
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<CloudIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div>
<Label>Cloud</Label>
<Description>
{cloudOverriddenByEnv
? "The OLLAMA_NO_CLOUD environment variable is currently forcing cloud off."
: "Enable cloud models and web search."}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={!cloudDisabled}
disabled={cloudToggleDisabled}
onChange={(checked) => {
if (cloudOverriddenByEnv) {
return;
}
updateCloudMutation.mutate(checked);
}}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">
@@ -440,35 +528,6 @@ export default function Settings() {
</div>
</div>
</Field>
{/* Airplane Mode */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<svg
className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100"
viewBox="0 0 21.5508 17.9033"
fill="currentColor"
>
<path d="M21.5508 8.94727C21.542 7.91895 20.1445 7.17188 18.4658 7.17188L14.9238 7.17188C14.4316 7.17188 14.2471 7.09277 13.957 6.75879L8.05078 0.316406C7.86621 0.105469 7.6377 0 7.37402 0L6.35449 0C6.12598 0 5.99414 0.202148 6.1084 0.448242L9.14941 7.17188L4.68457 7.68164L3.09375 4.76367C2.97949 4.54395 2.78613 4.44727 2.49609 4.44727L2.11816 4.44727C1.88965 4.44727 1.74023 4.59668 1.74023 4.8252L1.74023 13.0693C1.74023 13.2979 1.88965 13.4385 2.11816 13.4385L2.49609 13.4385C2.78613 13.4385 2.97949 13.3418 3.09375 13.1309L4.68457 10.2129L9.14941 10.7227L6.1084 17.4463C5.99414 17.6836 6.12598 17.8945 6.35449 17.8945L7.37402 17.8945C7.6377 17.8945 7.86621 17.7803 8.05078 17.5781L13.957 11.127C14.2471 10.8018 14.4316 10.7227 14.9238 10.7227L18.4658 10.7227C20.1445 10.7227 21.542 9.9668 21.5508 8.94727Z" />
</svg>
<div>
<Label>Airplane mode</Label>
<Description>
Airplane mode keeps data local, disabling cloud models
and web search.
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AirplaneMode}
onChange={(checked) =>
handleChange("AirplaneMode", checked)
}
/>
</div>
</div>
</Field>
</div>
</div>

View File

@@ -6,8 +6,8 @@ import { useSelectedModel } from "./useSelectedModel";
import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext";
import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
import { useCloudStatus } from "./useCloudStatus";
export const useChats = () => {
return useQuery({
@@ -116,11 +116,9 @@ export const useIsModelStale = (modelName: string) => {
export const useShouldShowStaleDisplay = (model: Model | null) => {
const isStale = useIsModelStale(model?.model || "");
const { data: dismissedModels } = useDismissedStaleModels();
const {
settings: { airplaneMode },
} = useSettings();
const { cloudDisabled } = useCloudStatus();
if (model?.isCloud() && !airplaneMode) {
if (model?.isCloud() && !cloudDisabled) {
return false;
}

View File

@@ -0,0 +1,20 @@
import { useQuery } from "@tanstack/react-query";
import { getCloudStatus, type CloudStatusResponse } from "@/api";
export function useCloudStatus() {
const cloudQuery = useQuery<CloudStatusResponse | null>({
queryKey: ["cloudStatus"],
queryFn: getCloudStatus,
retry: false,
staleTime: 60 * 1000,
});
return {
cloudStatus: cloudQuery.data,
cloudDisabled: cloudQuery.data?.disabled ?? false,
isKnown: cloudQuery.data !== null && cloudQuery.data !== undefined,
isLoading: cloudQuery.isLoading,
isError: cloudQuery.isError,
error: cloudQuery.error,
};
}

View File

@@ -2,11 +2,11 @@ import { useQuery } from "@tanstack/react-query";
import { Model } from "@/gotypes";
import { getModels } from "@/api";
import { mergeModels } from "@/utils/mergeModels";
import { useSettings } from "./useSettings";
import { useMemo } from "react";
import { useCloudStatus } from "./useCloudStatus";
export function useModels(searchQuery = "") {
const { settings } = useSettings();
const { cloudDisabled } = useCloudStatus();
const localQuery = useQuery<Model[], Error>({
queryKey: ["models", searchQuery],
queryFn: () => getModels(searchQuery),
@@ -20,7 +20,7 @@ export function useModels(searchQuery = "") {
});
const allModels = useMemo(() => {
const models = mergeModels(localQuery.data || [], settings.airplaneMode);
const models = mergeModels(localQuery.data || [], cloudDisabled);
if (searchQuery && searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
@@ -40,7 +40,7 @@ export function useModels(searchQuery = "") {
}
return models;
}, [localQuery.data, searchQuery, settings.airplaneMode]);
}, [localQuery.data, searchQuery, cloudDisabled]);
return {
...localQuery,

View File

@@ -7,6 +7,7 @@ import { Model } from "@/gotypes";
import { FEATURED_MODELS } from "@/utils/mergeModels";
import { getTotalVRAM } from "@/utils/vram.ts";
import { getInferenceCompute } from "@/api";
import { useCloudStatus } from "./useCloudStatus";
export function recommendDefaultModel(totalVRAM: number): string {
const vram = Math.max(0, Number(totalVRAM) || 0);
@@ -22,6 +23,7 @@ export function recommendDefaultModel(totalVRAM: number): string {
export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
const { settings, setSettings } = useSettings();
const { data: models = [], isLoading } = useModels(searchQuery || "");
const { cloudDisabled } = useCloudStatus();
const { data: chatData, isLoading: isChatLoading } = useChat(
currentChatId && currentChatId !== "new" ? currentChatId : "",
);
@@ -46,12 +48,11 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
const restoredChatRef = useRef<string | null>(null);
const selectedModel: Model | null = useMemo(() => {
// if airplane mode is on and selected model ends with cloud,
// switch to recommended default model
if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
// If cloud is disabled and selected model ends with cloud, switch to a local default.
if (cloudDisabled && settings.selectedModel?.endsWith("cloud")) {
return (
models.find((m) => m.model === recommendedModel) ||
models.find((m) => m.isCloud) ||
models.find((m) => !m.isCloud()) ||
models.find((m) => m.digest === undefined || m.digest === "") ||
models[0] ||
null
@@ -68,7 +69,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
"qwen3-coder:480b",
];
const shouldMigrate =
!settings.airplaneMode &&
!cloudDisabled &&
settings.turboEnabled &&
baseModelsToMigrate.includes(settings.selectedModel);
@@ -96,13 +97,18 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
})) ||
null
);
}, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
}, [
models,
settings.selectedModel,
cloudDisabled,
recommendedModel,
]);
useEffect(() => {
if (!selectedModel) return;
if (
settings.airplaneMode &&
cloudDisabled &&
settings.selectedModel?.endsWith("cloud") &&
selectedModel.model !== settings.selectedModel
) {
@@ -110,13 +116,17 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
}
if (
!settings.airplaneMode &&
!cloudDisabled &&
settings.turboEnabled &&
selectedModel.model !== settings.selectedModel
) {
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
}
}, [selectedModel, settings.airplaneMode, settings.selectedModel]);
}, [
selectedModel,
cloudDisabled,
settings.selectedModel,
]);
// Set model from chat history when chat data loads
useEffect(() => {
@@ -169,7 +179,9 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
const defaultModel =
models.find((m) => m.model === recommendedModel) ||
models.find((m) => m.isCloud()) ||
(cloudDisabled
? models.find((m) => !m.isCloud())
: models.find((m) => m.isCloud())) ||
models.find((m) => m.digest === undefined || m.digest === "") ||
models[0];
@@ -181,6 +193,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
inferenceComputes.length,
models.length,
settings.selectedModel,
cloudDisabled,
]);
// Add the selected model to the models list if it's not already there

View File

@@ -9,7 +9,6 @@ interface SettingsState {
webSearchEnabled: boolean;
selectedModel: string;
sidebarOpen: boolean;
airplaneMode: boolean;
thinkEnabled: boolean;
thinkLevel: string;
}
@@ -51,7 +50,6 @@ export function useSettings() {
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
selectedModel: settingsData?.settings?.SelectedModel ?? "",
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
airplaneMode: settingsData?.settings?.AirplaneMode ?? false,
}),
[settingsData?.settings],
);

View File

@@ -2,6 +2,7 @@ import type { QueryClient } from "@tanstack/react-query";
import { createRootRouteWithContext, Outlet } from "@tanstack/react-router";
import { getSettings } from "@/api";
import { useQuery } from "@tanstack/react-query";
import { useCloudStatus } from "@/hooks/useCloudStatus";
function RootComponent() {
// This hook ensures settings are fetched on app startup
@@ -9,6 +10,8 @@ function RootComponent() {
queryKey: ["settings"],
queryFn: getSettings,
});
// Fetch cloud status on startup (best-effort)
useCloudStatus();
return (
<div>

View File

@@ -41,14 +41,14 @@ describe("Model merging logic", () => {
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
});
it("should hide cloud models in airplane mode", () => {
it("should hide cloud models when cloud is disabled", () => {
const localModels: Model[] = [
new Model({ model: "gpt-oss:120b-cloud" }),
new Model({ model: "llama3:latest" }),
new Model({ model: "mistral:latest" }),
];
const merged = mergeModels(localModels, true); // airplane mode = true
const merged = mergeModels(localModels, true); // cloud disabled = true
// No cloud models should be present
const cloudModels = merged.filter((m) => m.isCloud());

View File

@@ -32,7 +32,7 @@ function alphabeticalSort(a: Model, b: Model): number {
//Merges models, sorting cloud models first, then other models
export function mergeModels(
localModels: Model[],
airplaneMode: boolean = false,
hideCloudModels: boolean = false,
): Model[] {
const allModels = (localModels || []).map((model) => model);
@@ -95,7 +95,7 @@ export function mergeModels(
remainingModels.sort(alphabeticalSort);
return airplaneMode
return hideCloudModels
? [...featuredModels, ...remainingModels]
: [...cloudModels, ...featuredModels, ...remainingModels];
}

View File

@@ -284,12 +284,15 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
mux.Handle("GET /api/v1/cloud", handle(s.getCloudSetting))
mux.Handle("POST /api/v1/cloud", handle(s.cloudSetting))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/version", ollamaProxy)
mux.Handle("GET /api/status", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy)
mux.Handle("POST /api/me", ollamaProxy)
mux.Handle("POST /api/signout", ollamaProxy)
@@ -1460,6 +1463,40 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
})
}
func (s *Server) cloudSetting(w http.ResponseWriter, r *http.Request) error {
var req struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := s.Store.SetCloudEnabled(req.Enabled); err != nil {
return fmt.Errorf("failed to persist cloud setting: %w", err)
}
s.Restart()
return s.writeCloudStatus(w)
}
func (s *Server) getCloudSetting(w http.ResponseWriter, r *http.Request) error {
return s.writeCloudStatus(w)
}
func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
disabled, source, err := s.Store.CloudStatus()
if err != nil {
return fmt.Errorf("failed to load cloud status: %w", err)
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(map[string]any{
"disabled": disabled,
"source": source,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()

View File

@@ -115,6 +115,107 @@ func TestHandlePostApiSettings(t *testing.T) {
}
}
func TestHandlePostApiCloudSetting(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
t.Setenv("OLLAMA_NO_CLOUD", "")
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
restartCount := 0
server := &Server{
Store: testStore,
Restart: func() {
restartCount++
},
}
for _, tc := range []struct {
name string
body string
wantEnabled bool
}{
{name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
{name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
} {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/v1/cloud", bytes.NewBufferString(tc.body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.cloudSetting(rr, req); err != nil {
t.Fatalf("cloudSetting() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
}
var got map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
t.Fatalf("cloudSetting() invalid response JSON: %v", err)
}
if got["disabled"] != !tc.wantEnabled {
t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
}
disabled, err := testStore.CloudDisabled()
if err != nil {
t.Fatalf("CloudDisabled() error = %v", err)
}
if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
}
})
}
if restartCount != 2 {
t.Fatalf("Restart called %d times, want 2", restartCount)
}
}
func TestHandleGetApiCloudSetting(t *testing.T) {
tmpHome := t.TempDir()
t.Setenv("HOME", tmpHome)
t.Setenv("OLLAMA_NO_CLOUD", "")
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
if err := testStore.SetCloudEnabled(false); err != nil {
t.Fatalf("SetCloudEnabled(false) error = %v", err)
}
server := &Server{
Store: testStore,
Restart: func() {},
}
req := httptest.NewRequest("GET", "/api/v1/cloud", nil)
rr := httptest.NewRecorder()
if err := server.getCloudSetting(rr, req); err != nil {
t.Fatalf("getCloudSetting() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
}
var got map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
}
if got["disabled"] != true {
t.Fatalf("response disabled = %v, want true", got["disabled"])
}
if got["source"] != "config" {
t.Fatalf("response source = %v, want config", got["source"])
}
}
func TestAuthenticationMiddleware(t *testing.T) {
tests := []struct {
name string

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
// SignRequest adds a nonce query parameter and an Authorization header with
// an Ed25519 signature to req.
func SignRequest(ctx context.Context, req *http.Request) error {
nonce, err := NewNonce(rand.Reader, 16)
if err != nil {
return err
}
q := req.URL.Query()
q.Set("nonce", nonce)
req.URL.RawQuery = q.Encode()
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
signature, err := Sign(ctx, data)
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}

View File

@@ -57,9 +57,9 @@ import (
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems)
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
@@ -581,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the
// model.
client, err := api.ClientFromEnvironment()
@@ -1885,13 +1900,25 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
if version.HasCachedUpdate() {
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
_ = version.ClearCachedUpdate()
}
result, err := tui.SelectSingle(title, tuiItems)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if available, err := version.CheckForUpdate(ctx); err == nil && available {
_ = version.CacheAvailableUpdate()
}
}()
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
@@ -1899,10 +1926,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := make([]tui.SelectItem, len(items))
for i, item := range items {
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
@@ -1949,7 +1973,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
launchIntegration := func(name string) bool {
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
return false // Return to main menu
@@ -1971,7 +1995,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
case tui.SelectionRunModel:
_ = config.SetLastSelection("run")
if modelName := config.LastModel(); modelName != "" {
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
runModel(modelName)
} else {
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
@@ -1999,6 +2023,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
continue
}
}
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
continue // Return to main menu
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
@@ -2008,6 +2035,17 @@ func runInteractiveTUI(cmd *cobra.Command) {
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
if len(result.Models) > 0 {
// Filter out cloud-disabled models
var filtered []string
for _, m := range result.Models {
if !config.IsCloudModelDisabled(cmd.Context(), m) {
filtered = append(filtered, m)
}
}
if len(filtered) == 0 {
continue
}
result.Models = filtered
// Multi-select from modal (Editor integrations)
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
@@ -2017,8 +2055,11 @@ func runInteractiveTUI(cmd *cobra.Command) {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else if result.Model != "" {
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
continue
}
// Single-select from modal - save and launch
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
@@ -2130,6 +2171,9 @@ func NewCLI() *cobra.Command {
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
runCmd.Flags().MarkHidden("imagegen")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
@@ -2273,6 +2317,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_MAX_QUEUE"],
envVars["OLLAMA_MODELS"],
envVars["OLLAMA_NUM_PARALLEL"],
envVars["OLLAMA_NO_CLOUD"],
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
@@ -2287,6 +2332,18 @@ func NewCLI() *cobra.Command {
}
}
updateCmd := &cobra.Command{
Use: "update",
Short: "Update Ollama to the latest version",
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
force, _ := cmd.Flags().GetBool("force")
_ = version.ClearCachedUpdate()
return version.DoUpdate(force)
},
}
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
rootCmd.AddCommand(
serveCmd,
createCmd,
@@ -2304,6 +2361,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
updateCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)

View File

@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items)
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
if err != nil {
return nil, false, err
}

View File

@@ -140,7 +140,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"qwen3:8b"})
SaveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
@@ -162,7 +162,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"llama3.2:70b"})
SaveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
@@ -187,7 +187,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"saved-model"})
SaveIntegration("claude", []string{"saved-model"})
saveAliases("claude", map[string]string{"primary": "saved-model"})
got := envMap(c.modelEnvVars("different-model"))

123
cmd/config/cline.go Normal file
View File

@@ -0,0 +1,123 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/envconfig"
)
// Cline implements Runner and Editor for the Cline CLI integration
type Cline struct{}
func (c *Cline) String() string { return "Cline" }
func (c *Cline) Run(model string, args []string) error {
if _, err := exec.LookPath("cline"); err != nil {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *Cline) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".cline", "data", "globalState.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Cline) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
}
}
// Set Ollama as the provider for both act and plan modes
baseURL := envconfig.Host().String()
config["ollamaBaseUrl"] = baseURL
config["actModeApiProvider"] = "ollama"
config["actModeOllamaModelId"] = models[0]
config["actModeOllamaBaseUrl"] = baseURL
config["planModeApiProvider"] = "ollama"
config["planModeOllamaModelId"] = models[0]
config["planModeOllamaBaseUrl"] = baseURL
config["welcomeViewCompleted"] = true
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Cline) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}
if config["actModeApiProvider"] != "ollama" {
return nil
}
modelID, _ := config["actModeOllamaModelId"].(string)
if modelID == "" {
return nil
}
return []string{modelID}
}

204
cmd/config/cline_test.go Normal file
View File

@@ -0,0 +1,204 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestClineIntegration(t *testing.T) {
c := &Cline{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Cline" {
t.Errorf("String() = %q, want %q", got, "Cline")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClineEdit(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var config map[string]any
json.Unmarshal(data, &config)
return config
}
t.Run("creates config from scratch", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeApiProvider"] != "ollama" {
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
}
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
}
if config["planModeApiProvider"] != "ollama" {
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
}
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
}
if config["welcomeViewCompleted"] != true {
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
}
})
t.Run("preserves existing fields", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
os.MkdirAll(configDir, 0o755)
existing := map[string]any{
"remoteRulesToggles": map[string]any{},
"remoteWorkflowToggles": map[string]any{},
"customSetting": "keep-me",
}
data, _ := json.Marshal(existing)
os.WriteFile(configPath, data, 0o644)
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["customSetting"] != "keep-me" {
t.Errorf("customSetting was not preserved")
}
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
})
t.Run("updates model on re-edit", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
if config["planModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(nil); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Error("expected no config file to be created for empty models")
}
})
t.Run("uses first model as primary", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
}
})
}
func TestClineModels(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
t.Run("returns nil when no config", func(t *testing.T) {
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "anthropic",
"actModeOllamaModelId": "some-model",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns model when ollama is configured", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "ollama",
"actModeOllamaModelId": "kimi-k2.5:cloud",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
}
})
}
func TestClinePaths(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns nil when no config exists", func(t *testing.T) {
if paths := c.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".cline", "data")
os.MkdirAll(configDir, 0o755)
configPath := filepath.Join(configDir, "globalState.json")
os.WriteFile(configPath, []byte("{}"), 0o644)
paths := c.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}

View File

@@ -56,8 +56,8 @@ func migrateConfig() (bool, error) {
return false, err
}
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
// Ignore legacy files with invalid JSON and continue startup.
if !json.Valid(oldData) {
return false, nil
}
@@ -126,7 +126,7 @@ func save(cfg *config) error {
return writeWithBackup(path, data)
}
func saveIntegration(appName string, models []string) error {
func SaveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}

View File

@@ -85,7 +85,7 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
@@ -604,7 +604,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
@@ -619,7 +619,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
@@ -634,7 +634,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
@@ -650,7 +650,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
@@ -662,7 +662,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}

View File

@@ -27,7 +27,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := saveIntegration("claude", models); err != nil {
if err := SaveIntegration("claude", models); err != nil {
t.Fatal(err)
}
@@ -48,7 +48,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
if err := SaveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
@@ -74,14 +74,14 @@ func TestIntegrationConfig(t *testing.T) {
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
@@ -94,7 +94,7 @@ func TestIntegrationConfig(t *testing.T) {
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
SaveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
defaultModel := ""
@@ -118,7 +118,7 @@ func TestIntegrationConfig(t *testing.T) {
})
t.Run("app name is case-insensitive", func(t *testing.T) {
saveIntegration("Claude", []string{"model-x"})
SaveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
if err != nil {
@@ -134,8 +134,8 @@ func TestIntegrationConfig(t *testing.T) {
})
t.Run("multiple integrations in single file", func(t *testing.T) {
saveIntegration("app1", []string{"model-1"})
saveIntegration("app2", []string{"model-2"})
SaveIntegration("app1", []string{"model-1"})
SaveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
@@ -172,8 +172,8 @@ func TestListIntegrations(t *testing.T) {
})
t.Run("returns all saved integrations", func(t *testing.T) {
saveIntegration("claude", []string{"model-1"})
saveIntegration("droid", []string{"model-2"})
SaveIntegration("claude", []string{"model-1"})
SaveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
@@ -261,7 +261,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveIntegration("test", nil); err != nil {
if err := SaveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
@@ -281,7 +281,7 @@ func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := saveIntegration("", []string{"model"})
err := SaveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
@@ -511,7 +511,7 @@ func TestMigrateConfig(t *testing.T) {
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
// load triggers migration, then save should write to new path
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
t.Fatal(err)
}

View File

@@ -3,6 +3,7 @@ package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
@@ -51,6 +52,16 @@ func (d *Droid) Run(model string, args []string) error {
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
return selectModels(context.Background(), "droid", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"maps"
"net/http"
"os"
"os/exec"
"runtime"
@@ -13,6 +13,7 @@ import (
"time"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -52,6 +53,7 @@ type AliasConfigurer interface {
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"cline": &Cline{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
@@ -100,16 +102,17 @@ var recommendedVRAM = map[string]string{
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
// integrationInstallHints maps integration names to install URLs.
var integrationInstallHints = map[string]string{
"claude": "https://code.claude.com/docs/en/quickstart",
"cline": "https://cline.bot/cli",
"openclaw": "https://docs.openclaw.ai",
"codex": "https://developers.openai.com/codex/cli/",
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
"opencode": "https://opencode.ai",
"pi": "https://github.com/badlogic/pi-mono",
}
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
@@ -127,13 +130,21 @@ type IntegrationInfo struct {
// integrationDescriptions maps integration names to short descriptions.
var integrationDescriptions = map[string]string{
"claude": "Anthropic's coding tool with subagents",
"cline": "Autonomous coding agent with parallel execution",
"codex": "OpenAI's open-source coding agent",
"openclaw": "Personal AI with 100+ skills",
"droid": "Factory's coding agent across terminal and IDEs",
"opencode": "Anomaly's open-source coding agent",
"pi": "Minimal AI agent toolkit with plugin support",
}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
// integrationOrder defines a custom display order for integrations.
// Integrations listed here are placed at the end in the given order;
// all others appear first, sorted alphabetically.
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
// with integrationOrder entries placed at the end.
func ListIntegrationInfos() []IntegrationInfo {
var result []IntegrationInfo
for name, r := range integrations {
@@ -146,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
Description: integrationDescriptions[name],
})
}
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
}
slices.SortFunc(result, func(a, b IntegrationInfo) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
// Both have custom order: sort by their rank
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
// Only one has custom order: it goes last
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
// Neither has custom order: alphabetical
return strings.Compare(a.Name, b.Name)
})
return result
@@ -184,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "cline":
_, err := exec.LookPath("cline")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
case "pi":
_, err := exec.LookPath("pi")
return err == nil
default:
return true // Assume installed for unknown integrations
}
@@ -212,7 +248,8 @@ type ModelItem struct {
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
@@ -234,6 +271,11 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
if cloudDisabled {
existing = filterCloudModels(existing)
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
@@ -242,11 +284,15 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
if cloudDisabled {
items = filterCloudItems(items)
}
if len(items) == 0 {
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
selected, err := selector("Select model to run:", items)
selected, err := selector("Select model to run:", items, "")
if err != nil {
return "", err
}
@@ -356,13 +402,11 @@ func selectIntegration() (string, error) {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []ModelItem
for _, name := range names {
for name, r := range integrations {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
@@ -370,7 +414,25 @@ func selectIntegration() (string, error) {
items = append(items, ModelItem{Name: name, Description: description})
}
return DefaultSingleSelector("Select integration:", items)
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(items, func(a, b ModelItem) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
return strings.Compare(a.Name, b.Name)
})
return DefaultSingleSelector("Select integration:", items, "")
}
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
@@ -395,6 +457,11 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
if cloudDisabled {
existing = filterCloudModels(existing)
}
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
@@ -404,6 +471,10 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
if cloudDisabled {
items = filterCloudItems(items)
}
if len(items) == 0 {
return nil, fmt.Errorf("no models available")
}
@@ -419,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := single(prompt, items)
model, err := single(prompt, items, current)
if err != nil {
return nil, err
}
@@ -510,8 +581,17 @@ func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]b
})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
if cloudDisabled {
existing = filterCloudModels(existing)
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if cloudDisabled {
items = filterCloudItems(items)
}
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
@@ -540,6 +620,9 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
if len(selectedCloudModels) == 0 {
return nil
}
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
@@ -672,25 +755,6 @@ func LaunchIntegrationWithModel(name, modelName string) error {
return runIntegration(name, modelName, nil)
}
// SaveIntegrationModel saves the model for an integration.
func SaveIntegrationModel(name, modelName string) error {
// Load existing models and prepend the new one
var models []string
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
models = existing.Models
// Remove the model if it already exists
for i, m := range models {
if m == modelName {
models = append(models[:i], models[i+1:]...)
break
}
}
}
// Prepend the new model
models = append([]string{modelName}, models...)
return saveIntegration(name, models)
}
// SaveAndEditIntegration saves the models for an Editor integration and runs its Edit method
// to write the integration's config files.
func SaveAndEditIntegration(name string, models []string) error {
@@ -698,7 +762,7 @@ func SaveAndEditIntegration(name string, models []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
if err := saveIntegration(name, models); err != nil {
if err := SaveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
@@ -709,6 +773,29 @@ func SaveAndEditIntegration(name string, models []string) error {
return nil
}
// resolveEditorModels filters out cloud-disabled models before editor launch.
// If no models remain, it invokes picker to collect a valid replacement list.
func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) {
filtered := filterDisabledCloudModels(models)
if len(filtered) != len(models) {
if err := SaveIntegration(name, filtered); err != nil {
return nil, fmt.Errorf("failed to save: %w", err)
}
}
if len(filtered) > 0 {
return filtered, nil
}
selected, err := picker()
if err != nil {
return nil, err
}
if err := SaveIntegration(name, selected); err != nil {
return nil, fmt.Errorf("failed to save: %w", err)
}
return selected, nil
}
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
r, ok := integrations[name]
@@ -743,7 +830,7 @@ func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single
}
}
if err := saveIntegration(name, models); err != nil {
if err := SaveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
@@ -776,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
cline Cline
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
Examples:
ollama launch
@@ -837,6 +926,10 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
modelFlag = ""
}
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
@@ -864,7 +957,7 @@ Examples:
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
_ = SaveIntegration(name, []string{model})
}
}
}
@@ -876,7 +969,9 @@ Examples:
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
model = ""
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
model = ""
@@ -884,18 +979,16 @@ Examples:
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Show picker so user can change model (skip when --model flag provided)
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
@@ -908,7 +1001,7 @@ Examples:
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
if err := SaveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
@@ -946,11 +1039,24 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
models = filterDisabledCloudModels(models)
if len(models) == 0 {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
} else {
current := ""
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
current = saved.Models[0]
}
var err error
models, err = selectModels(cmd.Context(), name, "")
models, err = selectModels(cmd.Context(), name, current)
if errors.Is(err, errCancelled) {
return nil
}
@@ -974,7 +1080,7 @@ Examples:
}
}
if err := saveIntegration(name, models); err != nil {
if err := SaveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
@@ -1048,7 +1154,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
continue
}
items = append(items, rec)
if strings.HasSuffix(rec.Name, ":cloud") {
if isCloudModelName(rec.Name) {
cloudModels[rec.Name] = true
}
}
@@ -1153,7 +1259,55 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return items, preChecked, existingModels, cloudModels
}
// isCloudModel checks if a model is a cloud model using the Show API.
// IsCloudModelDisabled reports whether the given model name looks like a cloud
// model and cloud features are currently disabled on the server.
func IsCloudModelDisabled(ctx context.Context, name string) bool {
if !isCloudModelName(name) {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
disabled, _ := cloudStatusDisabled(ctx, client)
return disabled
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}
func filterCloudModels(existing []modelInfo) []modelInfo {
filtered := existing[:0]
for _, m := range existing {
if !m.Remote {
filtered = append(filtered, m)
}
}
return filtered
}
// filterDisabledCloudModels removes cloud models from a list when cloud is disabled.
func filterDisabledCloudModels(models []string) []string {
var filtered []string
for _, m := range models {
if !IsCloudModelDisabled(context.Background(), m) {
filtered = append(filtered, m)
}
}
return filtered
}
func filterCloudItems(items []ModelItem) []ModelItem {
filtered := items[:0]
for _, item := range items {
if !isCloudModelName(item.Name) {
filtered = append(filtered, item)
}
}
return filtered
}
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
@@ -1183,6 +1337,11 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
if cloudDisabled {
existing = filterCloudModels(existing)
}
lastModel := LastModel()
var preChecked []string
if lastModel != "" {
@@ -1191,9 +1350,25 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
if cloudDisabled {
items = filterCloudItems(items)
}
return items, existingModels
}
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, false
}
return false, false
}
return status.Cloud.Disabled, true
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()

View File

@@ -16,6 +16,28 @@ import (
"github.com/spf13/cobra"
)
type stubEditorRunner struct {
edited [][]string
ranModel string
}
func (s *stubEditorRunner) Run(model string, args []string) error {
s.ranModel = model
return nil
}
func (s *stubEditorRunner) String() string { return "StubEditor" }
func (s *stubEditorRunner) Paths() []string { return nil }
func (s *stubEditorRunner) Edit(models []string) error {
cloned := append([]string(nil), models...)
s.edited = append(s.edited, cloned)
return nil
}
func (s *stubEditorRunner) Models() []string { return nil }
func TestIntegrationLookup(t *testing.T) {
tests := []struct {
name string
@@ -149,6 +171,10 @@ func TestLaunchCmd_TUICallback(t *testing.T) {
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
@@ -680,7 +706,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
if err := SaveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
@@ -697,6 +723,137 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
}
}
func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
pickerCalled := false
models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) {
pickerCalled = true
return []string{"llama3.2"}, nil
})
if err != nil {
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
}
if !pickerCalled {
t.Fatal("expected model picker to be called when all models are filtered")
}
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
}
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
pickerCalled := false
models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) {
pickerCalled = true
return []string{"qwen3:8b"}, nil
})
if err != nil {
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
}
if pickerCalled {
t.Fatal("picker should not be called when a local model remains")
}
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
}
saved, err := loadIntegration("droid")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/show":
fmt.Fprintf(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &stubEditorRunner{}
old, existed := integrations["stubeditor"]
integrations["stubeditor"] = stub
defer func() {
if existed {
integrations["stubeditor"] = old
} else {
delete(integrations, "stubeditor")
}
}()
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
saved, err := loadIntegration("stubeditor")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
@@ -1091,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
}
})
t.Run("sorted by name", func(t *testing.T) {
t.Run("sorted with custom order at end", func(t *testing.T) {
// integrationOrder entries (cline, opencode) should appear last, in that order.
// All other entries should be sorted alphabetically before them.
orderRank := make(map[string]int)
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
for i := 1; i < len(infos); i++ {
if infos[i-1].Name >= infos[i].Name {
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
switch {
case aRank == 0 && bRank == 0:
if infos[i-1].Name >= infos[i].Name {
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
}
case aRank > 0 && bRank == 0:
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
case aRank > 0 && bRank > 0:
if aRank >= bRank {
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
}
}
}
})
@@ -1234,7 +1407,7 @@ func TestIntegrationModels(t *testing.T) {
})
t.Run("returns all saved models", func(t *testing.T) {
if err := saveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
if err := SaveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
t.Fatal(err)
}
got := IntegrationModels("droid")

View File

@@ -2,7 +2,9 @@ package config
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
@@ -32,6 +34,16 @@ func (c *Openclaw) Run(model string, args []string) error {
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
return selectModels(context.Background(), "openclaw", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
@@ -58,7 +70,7 @@ func (c *Openclaw) Run(model string, args []string) error {
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err := cmd.Run()
err = cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil

View File

@@ -3,6 +3,7 @@ package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"os"
@@ -51,6 +52,16 @@ func (o *OpenCode) Run(model string, args []string) error {
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
return selectModels(context.Background(), "opencode", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}

View File

@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
return s
}
func SelectSingle(title string, items []SelectItem) (string, error) {
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
if current != "" {
for i, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
}
return 0
}
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
p := tea.NewProgram(m)
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if len(m.checkOrder) > 0 {
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
m.toggleItem()
if m.multi {
m.toggleItem()
}
case tea.KeyUp:
if m.cursor > 0 {
@@ -576,15 +616,36 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
if m.multi {
m.toggleItem()
}
} else {
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
}
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
@@ -596,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
@@ -618,6 +679,11 @@ func (m multiSelectorModel) View() string {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
@@ -642,7 +708,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(filtered) {
break
}
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -665,7 +731,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
}
@@ -685,7 +751,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(otherItems) {
break
}
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -697,15 +763,18 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n")
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
@@ -728,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
if fm.cancelled || !fm.confirmed {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
}
var result []string
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
if idx != last {
result = append(result, fm.items[idx].Name)
}
}
return result, nil
}

View File

@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
}
}
// --- cursorForCurrent ---
func TestCursorForCurrent(t *testing.T) {
testItems := []SelectItem{
{Name: "llama3.2", Recommended: true},
{Name: "qwen3:8b", Recommended: true},
{Name: "gemma3:latest"},
{Name: "deepseek-r1"},
{Name: "glm-5:cloud"},
}
tests := []struct {
name string
current string
want int
}{
{"empty current", "", 0},
{"exact match", "qwen3:8b", 1},
{"no match returns 0", "nonexistent", 0},
{"bare name matches with :latest suffix", "gemma3", 2},
{"full tag matches bare item", "llama3.2:latest", 0},
{"cloud model exact match", "glm-5:cloud", 4},
{"cloud model bare name", "glm-5", 4},
{"recommended item exact match", "llama3.2", 0},
{"recommended item with tag", "qwen3", 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
}
})
}
}
// --- ReorderItems ---
func TestReorderItems(t *testing.T) {
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
func TestMultiView_CheckedItemShowsX(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m.multi = true
content := m.View()
if !strings.Contains(content, "[x]") {
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}
func TestMultiView_DefaultTag(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
m.multi = true
content := m.View()
if !strings.Contains(content, "(default)") {
t.Error("first checked item should have (default) tag")
t.Error("should have (default) tag")
}
// preChecked[0] ("a") should be the default (last in checkOrder)
aIdx := strings.Index(content, "a")
defaultIdx := strings.Index(content, "(default)")
if defaultIdx < aIdx {
t.Error("(default) tag should appear after 'a' (the current default)")
}
}
@@ -545,6 +589,200 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
}
}
// --- Multi-select space toggle (including KeyRunes fallback for Windows PowerShell) ---
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeySpace
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if !m.checked[1] {
t.Error("space (KeySpace) should toggle the item at cursor")
}
if m.filter != "" {
t.Error("space should not modify filter")
}
}
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if !m.checked[1] {
t.Error("space (KeyRunes) should toggle the item at cursor")
}
if m.filter != "" {
t.Error("space rune should not be added to filter")
}
if m.cursor != 1 {
t.Errorf("cursor should stay at 1, got %d", m.cursor)
}
}
// --- Single-add mode ---
func TestMulti_StartsInSingleMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Error("should start in single mode (multi=false)")
}
}
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
content := m.View()
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
t.Error("single mode should not show checkboxes")
}
if !strings.Contains(content, "▸") {
t.Error("single mode should show cursor indicator")
}
}
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.cursor = 1
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
m = updated.(multiSelectorModel)
if m.singleAdd != "b" {
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
}
if !m.confirmed {
t.Error("should set confirmed")
}
}
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space in single mode should not toggle items")
}
}
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space rune in single mode should not toggle items")
}
if m.filter != "" {
t.Error("space rune in single mode should not add to filter")
}
}
func TestMulti_TabTogglesMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Fatal("should start in single mode")
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if !m.multi {
t.Error("tab should switch to multi mode")
}
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if m.multi {
t.Error("tab should switch back to single mode")
}
}
func TestMulti_SingleModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
content := m.View()
if !strings.Contains(content, "tab add multiple") {
t.Error("single mode should show 'tab add multiple' in help")
}
}
func TestMulti_MultiModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
m.multi = true
content := m.View()
if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help")
}
}
// --- preChecked initialization order ---
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
// preChecked[0] ("a") is the current default and should end up
// last in checkOrder so it gets the (default) tag.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
if len(m.checkOrder) != 3 {
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
}
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
}
}
func TestMulti_CursorOnDefaultModel(t *testing.T) {
// preChecked[0] ("b") is the default; cursor should start on it
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
if m.cursor != 1 {
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
}
}
// --- Multi-mode last-checked is default ---
func TestMulti_LastCheckedIsDefault(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
m.multi = true
// Check "alpha" then "gamma"
m.cursor = 0
m.toggleItem()
m.cursor = 2
m.toggleItem()
// Last checked ("gamma") should be at the end of checkOrder
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "gamma" {
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
}
// The (default) tag renders based on checkOrder[len-1]
content := m.View()
if !strings.Contains(content, "(default)") {
t.Fatal("should show (default) tag")
}
// "alpha" line should NOT have the default tag
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
t.Error("'alpha' (first checked) should not have (default) tag")
}
}
}
// Key message helpers for testing
type keyType = int

View File

@@ -131,7 +131,7 @@ type model struct {
signInURL string
signInModel string
signInSpinner int
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
width int // terminal width from WindowSizeMsg
statusMsg string // temporary status message shown near help text
@@ -209,7 +209,26 @@ func (m *model) openMultiModelModal(integration string) {
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}
func cloudStatusDisabled(client *api.Client) bool {
status, err := client.CloudStatusExperimental(context.Background())
if err != nil {
return false
}
return status.Cloud.Disabled
}
func cloudModelDisabled(name string) bool {
if !isCloudModel(name) {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
return cloudStatusDisabled(client)
}
// checkCloudSignIn checks if a cloud model needs sign-in.
@@ -222,6 +241,9 @@ func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
if err != nil {
return nil
}
if cloudStatusDisabled(client) {
return nil
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return nil
@@ -272,7 +294,11 @@ func (m *model) loadAvailableModels() {
if err != nil {
return
}
cloudDisabled := cloudStatusDisabled(client)
for _, mdl := range models.Models {
if cloudDisabled && mdl.RemoteModel != "" {
continue
}
m.availableModels[mdl.Name] = true
}
}
@@ -403,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.multiModalSelector.confirmed {
var selected []string
for _, idx := range m.multiModalSelector.checkOrder {
selected = append(selected, m.multiModalSelector.items[idx].Name)
if m.multiModalSelector.singleAdd != "" {
// Single-add mode: prepend picked model, keep existing deduped
selected = []string{m.multiModalSelector.singleAdd}
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
if name != m.multiModalSelector.singleAdd {
selected = append(selected, name)
}
}
} else {
// Last checked is default (first in result)
co := m.multiModalSelector.checkOrder
last := co[len(co)-1]
selected = []string{m.multiModalSelector.items[last].Name}
for _, idx := range co {
if idx != last {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
}
}
if len(selected) > 0 {
m.changeModels = selected
@@ -496,6 +538,15 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, cmd
}
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
if item.integration != "" && config.IsEditorIntegration(item.integration) {
m.openMultiModelModal(item.integration)
} else {
m.openModelModal(configuredModel)
}
return m, nil
}
m.selected = true
m.quitting = true
return m, tea.Quit

View File

@@ -226,3 +226,7 @@ curl https://ollama.com/api/chat \
</Tab>
</Tabs>
## Local only
Ollama can run in local-only mode by [disabling Ollama's cloud](./faq#how-do-i-disable-ollama-cloud) features.

View File

@@ -106,20 +106,23 @@
"group": "Integrations",
"pages": [
"/integrations/index",
{
"group": "Assistants",
"expanded": true,
"pages": [
"/integrations/openclaw"
]
},
{
"group": "Coding",
"expanded": true,
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose"
]
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
"/integrations/goose",
"/integrations/pi"
]
},
{

View File

@@ -160,6 +160,26 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
## How do I disable Ollama's cloud features?
Ollama can run in local only mode by disabling Ollama's cloud features. By turning off Ollama's cloud features, you will lose the ability to use Ollama's cloud models and web search.
Set `disable_ollama_cloud` in `~/.ollama/server.json`:
```json
{
"disable_ollama_cloud": true
}
```
You can also set the environment variable:
```shell
OLLAMA_NO_CLOUD=1
```
Restart Ollama after changing configuration. Once disabled, Ollama's logs will show `Ollama cloud disabled: true`.
## How can I expose Ollama on my network?
Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.

View File

@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
- [Pi](/integrations/pi)
## Assistants

57
docs/integrations/pi.mdx Normal file
View File

@@ -0,0 +1,57 @@
---
title: Pi
---
Pi is a minimal AI agent toolkit with plugin support.
## Install
Install [Pi](https://github.com/badlogic/pi-mono):
```bash
npm install -g @mariozechner/pi-coding-agent
```
## Usage with Ollama
### Quick setup
```bash
ollama launch pi
```
To configure without launching:
```shell
ollama launch pi --config
```
### Manual setup
Add a configuration block to `~/.pi/agent/models.json`:
```json
{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{
"id": "qwen3-coder"
}
]
}
}
}
```
Update `~/.pi/agent/settings.json` to set the default provider:
```json
{
"defaultProvider": "ollama",
"defaultModel": "qwen3-coder"
}
```

View File

@@ -27,9 +27,17 @@ The menu provides quick access to:
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
- **Additional integrations** - Available under "More..."
## Assistants
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
```sh
ollama launch openclaw
```
## Coding
Launch coding tools with Ollama models:
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
```sh
ollama launch claude

View File

@@ -1,6 +1,8 @@
package envconfig
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"math"
@@ -11,6 +13,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
)
@@ -206,6 +209,8 @@ var (
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
// NoCloudEnv checks the OLLAMA_NO_CLOUD environment variable.
NoCloudEnv = Bool("OLLAMA_NO_CLOUD")
)
func String(s string) func() string {
@@ -285,6 +290,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
@@ -334,3 +340,91 @@ func Values() map[string]string {
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
// serverConfigData holds the parsed fields from ~/.ollama/server.json.
type serverConfigData struct {
DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
}
var (
serverCfgMu sync.RWMutex
serverCfgLoaded bool
serverCfg serverConfigData
)
func loadServerConfig() {
serverCfgMu.RLock()
if serverCfgLoaded {
serverCfgMu.RUnlock()
return
}
serverCfgMu.RUnlock()
cfg := serverConfigData{}
home, err := os.UserHomeDir()
if err == nil {
path := filepath.Join(home, ".ollama", "server.json")
data, err := os.ReadFile(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
slog.Debug("envconfig: could not read server config", "error", err)
}
} else if err := json.Unmarshal(data, &cfg); err != nil {
slog.Debug("envconfig: could not parse server config", "error", err)
}
}
serverCfgMu.Lock()
defer serverCfgMu.Unlock()
if serverCfgLoaded {
return
}
serverCfg = cfg
serverCfgLoaded = true
}
func cachedServerConfig() serverConfigData {
serverCfgMu.RLock()
defer serverCfgMu.RUnlock()
return serverCfg
}
// ReloadServerConfig refreshes the cached ~/.ollama/server.json settings.
func ReloadServerConfig() {
serverCfgMu.Lock()
serverCfgLoaded = false
serverCfg = serverConfigData{}
serverCfgMu.Unlock()
loadServerConfig()
}
// NoCloud returns true if Ollama cloud features are disabled,
// checking both the OLLAMA_NO_CLOUD environment variable and
// the disable_ollama_cloud field in ~/.ollama/server.json.
func NoCloud() bool {
if NoCloudEnv() {
return true
}
loadServerConfig()
return cachedServerConfig().DisableOllamaCloud
}
// NoCloudSource returns the source of the cloud-disabled decision.
// Returns "none", "env", "config", or "both".
func NoCloudSource() string {
envDisabled := NoCloudEnv()
loadServerConfig()
configDisabled := cachedServerConfig().DisableOllamaCloud
switch {
case envDisabled && configDisabled:
return "both"
case envDisabled:
return "env"
case configDisabled:
return "config"
default:
return "none"
}
}

View File

@@ -3,6 +3,8 @@ package envconfig
import (
"log/slog"
"math"
"os"
"path/filepath"
"testing"
"time"
@@ -326,3 +328,81 @@ func TestLogLevel(t *testing.T) {
})
}
}
func TestNoCloud(t *testing.T) {
tests := []struct {
name string
envValue string
configContent string
wantDisabled bool
wantSource string
}{
{
name: "neither env nor config",
wantDisabled: false,
wantSource: "none",
},
{
name: "env only",
envValue: "1",
wantDisabled: true,
wantSource: "env",
},
{
name: "config only",
configContent: `{"disable_ollama_cloud": true}`,
wantDisabled: true,
wantSource: "config",
},
{
name: "both env and config",
envValue: "1",
configContent: `{"disable_ollama_cloud": true}`,
wantDisabled: true,
wantSource: "both",
},
{
name: "config false",
configContent: `{"disable_ollama_cloud": false}`,
wantDisabled: false,
wantSource: "none",
},
{
name: "invalid config ignored",
configContent: `{invalid json`,
wantDisabled: false,
wantSource: "none",
},
{
name: "no config file",
wantDisabled: false,
wantSource: "none",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
home := t.TempDir()
if tt.configContent != "" {
configDir := filepath.Join(home, ".ollama")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "server.json"), []byte(tt.configContent), 0o644); err != nil {
t.Fatal(err)
}
}
setTestHome(t, home)
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
if got := NoCloud(); got != tt.wantDisabled {
t.Errorf("NoCloud() = %v, want %v", got, tt.wantDisabled)
}
if got := NoCloudSource(); got != tt.wantSource {
t.Errorf("NoCloudSource() = %q, want %q", got, tt.wantSource)
}
})
}
}

View File

@@ -0,0 +1,10 @@
package envconfig
import "testing"
func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
ReloadServerConfig()
}

25
internal/cloud/policy.go Normal file
View File

@@ -0,0 +1,25 @@
package cloud
import (
"github.com/ollama/ollama/envconfig"
)
const DisabledMessagePrefix = "ollama cloud is disabled"
// Status returns whether cloud is disabled and the source of the decision.
// Source is one of: "none", "env", "config", "both".
func Status() (disabled bool, source string) {
return envconfig.NoCloud(), envconfig.NoCloudSource()
}
func Disabled() bool {
return envconfig.NoCloud()
}
func DisabledError(operation string) string {
if operation == "" {
return DisabledMessagePrefix
}
return DisabledMessagePrefix + ": " + operation
}

View File

@@ -0,0 +1,85 @@
package cloud
import (
"os"
"path/filepath"
"testing"
)
func TestStatus(t *testing.T) {
tests := []struct {
name string
envValue string
configContent string
disabled bool
source string
}{
{
name: "none",
disabled: false,
source: "none",
},
{
name: "env only",
envValue: "1",
disabled: true,
source: "env",
},
{
name: "config only",
configContent: `{"disable_ollama_cloud": true}`,
disabled: true,
source: "config",
},
{
name: "both",
envValue: "1",
configContent: `{"disable_ollama_cloud": true}`,
disabled: true,
source: "both",
},
{
name: "invalid config ignored",
configContent: `{invalid json`,
disabled: false,
source: "none",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
home := t.TempDir()
if tt.configContent != "" {
configPath := filepath.Join(home, ".ollama", "server.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
t.Fatal(err)
}
}
setTestHome(t, home)
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
disabled, source := Status()
if disabled != tt.disabled {
t.Fatalf("disabled: expected %v, got %v", tt.disabled, disabled)
}
if source != tt.source {
t.Fatalf("source: expected %q, got %q", tt.source, source)
}
})
}
}
func TestDisabledError(t *testing.T) {
if got := DisabledError(""); got != DisabledMessagePrefix {
t.Fatalf("expected %q, got %q", DisabledMessagePrefix, got)
}
want := DisabledMessagePrefix + ": remote inference is unavailable"
if got := DisabledError("remote inference is unavailable"); got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}

View File

@@ -0,0 +1,14 @@
package cloud
import (
"testing"
"github.com/ollama/ollama/envconfig"
)
func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
envconfig.ReloadServerConfig()
}

View File

@@ -2,15 +2,22 @@ package middleware
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/logutil"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
@@ -18,7 +25,6 @@ type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
@@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error))
if err != nil {
return 0, err
}
@@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
return writeSSE(w.ResponseWriter, eventType, data)
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
@@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
@@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
// executes the search, re-invokes the model with results, and assembles the
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
type WebSearchAnthropicWriter struct {
BaseWriter
newLoopContext func() (context.Context, context.CancelFunc)
inner *AnthropicWriter
req anthropic.MessagesRequest // original Anthropic request
chatReq *api.ChatRequest // converted Ollama request (for followup calls)
stream bool
estimatedInputTokens int
terminalSent bool
observedPromptEvalCount int
observedEvalCount int
loopInFlight bool
loopBaseInputTok int
loopBaseOutputTok int
loopResultCh chan webSearchLoopResult
streamMessageStarted bool
streamHasOpenBlock bool
streamOpenBlockIndex int
streamNextIndex int
}
const maxWebSearchLoops = 3
type webSearchLoopResult struct {
response anthropic.MessagesResponse
loopErr *webSearchLoopError
}
type webSearchLoopError struct {
code string
query string
usage anthropic.Usage
err error
}
func (e *webSearchLoopError) Error() string {
if e.err == nil {
return e.code
}
return fmt.Sprintf("%s: %v", e.code, e.err)
}
func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
if w.terminalSent {
return len(data), nil
}
code := w.Status()
if code != http.StatusOK {
return w.inner.writeError(data)
}
var chatResponse api.ChatResponse
if err := json.Unmarshal(data, &chatResponse); err != nil {
return 0, err
}
w.recordObservedUsage(chatResponse.Metrics)
if w.stream && w.loopInFlight {
if !chatResponse.Done {
return len(data), nil
}
if err := w.writeLoopResult(); err != nil {
return len(data), err
}
return len(data), nil
}
webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
logutil.Trace("anthropic middleware: upstream chunk",
"resp", anthropic.TraceChatResponse(chatResponse),
"web_search", hasWebSearch,
"other_tools", hasOtherTools,
)
if hasWebSearch && hasOtherTools {
// Prefer web_search if both server and client tools are present in one chunk.
slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
}
if !hasWebSearch {
if w.stream {
if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
return 0, err
}
return len(data), nil
}
return w.inner.writeResponse(data)
}
if w.stream {
// Let the original generation continue to completion while web search runs in parallel.
logutil.Trace("anthropic middleware: starting async web_search loop",
"tool_call", anthropic.TraceToolCall(webSearchCall),
"resp", anthropic.TraceChatResponse(chatResponse),
)
w.startLoopWorker(chatResponse, webSearchCall)
if chatResponse.Done {
if err := w.writeLoopResult(); err != nil {
return len(data), err
}
}
return len(data), nil
}
loopCtx, cancel := w.startLoopContext()
defer cancel()
initialUsage := anthropic.Usage{
InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
}
logutil.Trace("anthropic middleware: starting sync web_search loop",
"tool_call", anthropic.TraceToolCall(webSearchCall),
"resp", anthropic.TraceChatResponse(chatResponse),
"usage", initialUsage,
)
response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
if loopErr != nil {
return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
}
if err := w.writeTerminalResponse(response); err != nil {
return 0, err
}
return len(data), nil
}
func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
followUpMessages = append(followUpMessages, w.chatReq.Messages...)
followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
usage := initialUsage
logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
"model", w.req.Model,
"tool_call", anthropic.TraceToolCall(initialToolCall),
"messages", len(followUpMessages),
"tools", len(followUpTools),
"max_loops", maxWebSearchLoops,
)
currentResponse := initialResponse
currentToolCall := initialToolCall
var serverContent []anthropic.ContentBlock
if !isCloudModelName(w.req.Model) {
logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model")
return anthropic.MessagesResponse{}, &webSearchLoopError{
code: "web_search_not_supported_for_local_models",
query: extractQueryFromToolCall(&initialToolCall),
usage: usage,
}
}
for loop := 1; loop <= maxWebSearchLoops; loop++ {
query := extractQueryFromToolCall(&currentToolCall)
logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
"loop", loop,
"query", anthropic.TraceTruncateString(query),
"messages", len(followUpMessages),
)
if query == "" {
return anthropic.MessagesResponse{}, &webSearchLoopError{
code: "invalid_request",
query: "",
usage: usage,
}
}
const defaultMaxResults = 5
searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
if err != nil {
logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
"loop", loop,
"query", query,
"error", err,
)
return anthropic.MessagesResponse{}, &webSearchLoopError{
code: "unavailable",
query: query,
usage: usage,
err: err,
}
}
logutil.TraceContext(ctx, "anthropic middleware: web_search results",
"loop", loop,
"results", len(searchResp.Results),
)
toolUseID := loopServerToolUseID(w.inner.id, loop)
searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
serverContent = append(serverContent,
anthropic.ContentBlock{
Type: "server_tool_use",
ID: toolUseID,
Name: "web_search",
Input: map[string]any{"query": query},
},
anthropic.ContentBlock{
Type: "web_search_tool_result",
ToolUseID: toolUseID,
Content: searchResults,
},
)
assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
toolResultMsg := api.Message{
Role: "tool",
Content: formatWebSearchResultsForToolMessage(searchResp.Results),
ToolCallID: currentToolCall.ID,
}
followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
if err != nil {
logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
"loop", loop,
"query", query,
"error", err,
)
return anthropic.MessagesResponse{}, &webSearchLoopError{
code: "api_error",
query: query,
usage: usage,
err: err,
}
}
logutil.TraceContext(ctx, "anthropic middleware: followup response",
"loop", loop,
"resp", anthropic.TraceChatResponse(followUpResponse),
)
usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
usage.OutputTokens += followUpResponse.Metrics.EvalCount
nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
if hasWebSearch && hasOtherTools {
// Prefer web_search if both server and client tools are present in one chunk.
slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
}
if !hasWebSearch {
finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
"loop", loop,
"resp", anthropic.TraceMessagesResponse(finalResponse),
)
return finalResponse, nil
}
currentResponse = followUpResponse
currentToolCall = nextToolCall
}
maxLoopQuery := extractQueryFromToolCall(&currentToolCall)
maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
serverContent = append(serverContent,
anthropic.ContentBlock{
Type: "server_tool_use",
ID: maxLoopToolUseID,
Name: "web_search",
Input: map[string]any{"query": maxLoopQuery},
},
anthropic.ContentBlock{
Type: "web_search_tool_result",
ToolUseID: maxLoopToolUseID,
Content: anthropic.WebSearchToolResultError{
Type: "web_search_tool_result_error",
ErrorCode: "max_uses_exceeded",
},
},
)
maxResponse := anthropic.MessagesResponse{
ID: w.inner.id,
Type: "message",
Role: "assistant",
Model: w.req.Model,
Content: serverContent,
StopReason: "end_turn",
Usage: usage,
}
logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
"resp", anthropic.TraceMessagesResponse(maxResponse),
)
return maxResponse, nil
}
func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
if w.loopInFlight {
return
}
initialUsage := anthropic.Usage{
InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
}
w.loopBaseInputTok = initialUsage.InputTokens
w.loopBaseOutputTok = initialUsage.OutputTokens
w.loopResultCh = make(chan webSearchLoopResult, 1)
w.loopInFlight = true
logutil.Trace("anthropic middleware: loop worker started",
"usage", initialUsage,
"tool_call", anthropic.TraceToolCall(initialToolCall),
)
go func() {
ctx, cancel := w.startLoopContext()
defer cancel()
response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
w.loopResultCh <- webSearchLoopResult{
response: response,
loopErr: loopErr,
}
}()
}
func (w *WebSearchAnthropicWriter) writeLoopResult() error {
if w.loopResultCh == nil {
return w.sendError("api_error", "", w.currentObservedUsage())
}
result := <-w.loopResultCh
w.loopResultCh = nil
w.loopInFlight = false
if result.loopErr != nil {
logutil.Trace("anthropic middleware: loop worker returned error",
"code", result.loopErr.code,
"query", result.loopErr.query,
"usage", result.loopErr.usage,
"error", result.loopErr.err,
)
usage := result.loopErr.usage
w.applyObservedUsageDeltaToUsage(&usage)
return w.sendError(result.loopErr.code, result.loopErr.query, usage)
}
logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
w.applyObservedUsageDelta(&result.response)
return w.writeTerminalResponse(result.response)
}
func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
w.applyObservedUsageDeltaToUsage(&response.Usage)
}
func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
if metrics.PromptEvalCount > w.observedPromptEvalCount {
w.observedPromptEvalCount = metrics.PromptEvalCount
}
if metrics.EvalCount > w.observedEvalCount {
w.observedEvalCount = metrics.EvalCount
}
}
func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
usage.InputTokens += deltaIn
}
if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
usage.OutputTokens += deltaOut
}
}
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
return anthropic.Usage{
InputTokens: w.observedPromptEvalCount,
OutputTokens: w.observedEvalCount,
}
}
func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
if w.newLoopContext != nil {
return w.newLoopContext()
}
return context.WithTimeout(context.Background(), 5*time.Minute)
}
func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
content = append(content, serverContent...)
content = append(content, converted.Content...)
return anthropic.MessagesResponse{
ID: w.inner.id,
Type: "message",
Role: "assistant",
Model: w.req.Model,
Content: content,
StopReason: converted.StopReason,
StopSequence: converted.StopSequence,
Usage: usage,
}
}
func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
assistantMsg := api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{webSearchCall},
}
if response.Message.Content != "" {
assistantMsg.Content = response.Message.Content
}
if response.Message.Thinking != "" {
assistantMsg.Thinking = response.Message.Thinking
}
return assistantMsg
}
func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
var resultText strings.Builder
for _, r := range results {
fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
if r.Content != "" {
fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
}
resultText.WriteString("\n")
}
return resultText.String()
}
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
var webSearchCall api.ToolCall
hasWebSearch := false
hasOtherTools := false
for _, toolCall := range toolCalls {
if toolCall.Function.Name == "web_search" {
if !hasWebSearch {
webSearchCall = toolCall
hasWebSearch = true
}
continue
}
hasOtherTools = true
}
return webSearchCall, hasWebSearch, hasOtherTools
}
func loopServerToolUseID(messageID string, loop int) string {
base := serverToolUseID(messageID)
if loop <= 1 {
return base
}
return fmt.Sprintf("%s_%d", base, loop)
}
func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
streaming := false
followUp := api.ChatRequest{
Model: w.chatReq.Model,
Messages: messages,
Stream: &streaming,
Tools: tools,
Options: w.chatReq.Options,
}
body, err := json.Marshal(followUp)
if err != nil {
return api.ChatResponse{}, err
}
chatURL := envconfig.Host().String() + "/api/chat"
logutil.TraceContext(ctx, "anthropic middleware: followup request",
"url", chatURL,
"req", anthropic.TraceChatRequest(&followUp),
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
if err != nil {
return api.ChatResponse{}, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(httpReq)
if err != nil {
return api.ChatResponse{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
"status", resp.StatusCode,
"response", strings.TrimSpace(string(respBody)),
)
return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
}
var chatResp api.ChatResponse
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
return api.ChatResponse{}, err
}
logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
return chatResp, nil
}
func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
events := w.inner.converter.Process(chatResponse)
for _, event := range events {
switch e := event.Data.(type) {
case anthropic.MessageStartEvent:
w.streamMessageStarted = true
case anthropic.ContentBlockStartEvent:
w.streamHasOpenBlock = true
w.streamOpenBlockIndex = e.Index
if e.Index+1 > w.streamNextIndex {
w.streamNextIndex = e.Index + 1
}
case anthropic.ContentBlockStopEvent:
if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
w.streamHasOpenBlock = false
}
if e.Index+1 > w.streamNextIndex {
w.streamNextIndex = e.Index + 1
}
case anthropic.MessageStopEvent:
w.terminalSent = true
}
if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
return err
}
}
return nil
}
func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
if w.streamMessageStarted {
return nil
}
inputTokens := usage.InputTokens
if inputTokens == 0 {
inputTokens = w.estimatedInputTokens
}
if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
Type: "message_start",
Message: anthropic.MessagesResponse{
ID: w.inner.id,
Type: "message",
Role: "assistant",
Model: w.req.Model,
Content: []anthropic.ContentBlock{},
Usage: anthropic.Usage{
InputTokens: inputTokens,
},
},
}); err != nil {
return err
}
w.streamMessageStarted = true
return nil
}
func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
if !w.streamHasOpenBlock {
return nil
}
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
Type: "content_block_stop",
Index: w.streamOpenBlockIndex,
}); err != nil {
return err
}
if w.streamOpenBlockIndex+1 > w.streamNextIndex {
w.streamNextIndex = w.streamOpenBlockIndex + 1
}
w.streamHasOpenBlock = false
return nil
}
func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
for _, block := range content {
index := w.streamNextIndex
if block.Type == "text" {
emptyText := ""
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
Type: "content_block_start",
Index: index,
ContentBlock: anthropic.ContentBlock{
Type: "text",
Text: &emptyText,
},
}); err != nil {
return err
}
text := ""
if block.Text != nil {
text = *block.Text
}
if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: index,
Delta: anthropic.Delta{
Type: "text_delta",
Text: text,
},
}); err != nil {
return err
}
} else {
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
Type: "content_block_start",
Index: index,
ContentBlock: block,
}); err != nil {
return err
}
}
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
Type: "content_block_stop",
Index: index,
}); err != nil {
return err
}
w.streamNextIndex++
}
return nil
}
func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
if w.terminalSent {
return nil
}
if !w.stream {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
return err
}
w.terminalSent = true
return nil
}
if err := w.ensureStreamMessageStart(response.Usage); err != nil {
return err
}
if err := w.closeOpenStreamBlock(); err != nil {
return err
}
if err := w.writeStreamContentBlocks(response.Content); err != nil {
return err
}
if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
Type: "message_delta",
Delta: anthropic.MessageDelta{
StopReason: response.StopReason,
},
Usage: anthropic.DeltaUsage{
InputTokens: response.Usage.InputTokens,
OutputTokens: response.Usage.OutputTokens,
},
}); err != nil {
return err
}
if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
Type: "message_stop",
}); err != nil {
return err
}
w.terminalSent = true
return nil
}
// streamResponse emits a complete MessagesResponse as SSE events.
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
return w.writeTerminalResponse(response)
}
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
toolUseID := serverToolUseID(w.inner.id)
return anthropic.MessagesResponse{
ID: w.inner.id,
Type: "message",
Role: "assistant",
Model: w.req.Model,
Content: []anthropic.ContentBlock{
{
Type: "server_tool_use",
ID: toolUseID,
Name: "web_search",
Input: map[string]any{"query": query},
},
{
Type: "web_search_tool_result",
ToolUseID: toolUseID,
Content: anthropic.WebSearchToolResultError{
Type: "web_search_tool_result_error",
ErrorCode: errorCode,
},
},
},
StopReason: "end_turn",
Usage: usage,
}
}
// sendError sends a web search error response.
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
response := w.webSearchErrorResponse(errorCode, query, usage)
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
return w.writeTerminalResponse(response)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
requestCtx := c.Request.Context()
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
@@ -134,11 +865,10 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
innerWriter := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
}
@@ -148,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
if hasWebSearchTool(req.Tools) {
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
// for cloud models. Local models may still receive web_search tool definitions;
// execution is validated when the model actually emits a web_search tool call.
if isCloudModelName(req.Model) {
if disabled, _ := internalcloud.Status(); disabled {
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
return
}
}
c.Writer = &WebSearchAnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
newLoopContext: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(requestCtx, 5*time.Minute)
},
inner: innerWriter,
req: req,
chatReq: chatReq,
stream: req.Stream,
estimatedInputTokens: estimatedTokens,
}
} else {
c.Writer = innerWriter
}
c.Next()
}
}
// hasWebSearchTool checks if the request tools include a web_search tool
func hasWebSearchTool(tools []anthropic.Tool) bool {
for _, tool := range tools {
if strings.HasPrefix(tool.Type, "web_search") {
return true
}
}
return false
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}
// extractQueryFromToolCall extracts the search query from a web_search tool call
func extractQueryFromToolCall(tc *api.ToolCall) string {
q, ok := tc.Function.Arguments.Get("query")
if !ok {
return ""
}
if s, ok := q.(string); ok {
return s
}
return ""
}
// writeSSE writes a Server-Sent Event
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
return err
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// serverToolUseID derives a server tool use ID from a message ID
func serverToolUseID(messageID string) string {
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,22 @@
package middleware
import (
"testing"
"github.com/ollama/ollama/envconfig"
)
func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
envconfig.ReloadServerConfig()
}
// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
// so that cloud features are enabled for the duration of the test.
func enableCloudForTest(t *testing.T) {
t.Helper()
t.Setenv("OLLAMA_NO_CLOUD", "")
setTestHome(t, t.TempDir())
}

View File

@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
var p Parser
switch name {
case "qwen3":
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
name string
}{
{"passthrough"},
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"harmony"},
}

335
model/parsers/qwen3.go Normal file
View File

@@ -0,0 +1,335 @@
package parsers
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwen3ParserState int
const (
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
qwen3ParserStateThinkingStartedEatingWhitespace
qwen3ParserStateCollectingThinking
qwen3ParserStateThinkingDoneEatingWhitespace
qwen3ParserStateCollectingContent
qwen3ParserStateToolStartedEatingWhitespace
qwen3ParserStateCollectingToolContent
)
const (
qwen3ThinkingOpenTag = "<think>"
qwen3ThinkingCloseTag = "</think>"
qwen3ToolOpenTag = "<tool_call>"
qwen3ToolCloseTag = "</tool_call>"
)
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
// with thinking content directly (without an opening tag).
type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
}
func (p *Qwen3Parser) HasToolSupport() bool {
return true
}
func (p *Qwen3Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
thinkingEnabled = p.defaultThinking
}
if p.hasThinkingSupport && thinkingEnabled {
p.state = qwen3ParserStateCollectingThinking
p.maybeThinkingOpenAtBOL = true
} else {
p.state = qwen3ParserStateCollectingContent
p.maybeThinkingOpenAtBOL = false
}
return tools
}
type qwen3Event interface {
isQwen3Event()
}
type qwen3EventContent struct {
content string
}
func (qwen3EventContent) isQwen3Event() {}
type qwen3EventRawToolCall struct {
raw string
}
func (qwen3EventRawToolCall) isQwen3Event() {}
type qwen3EventThinkingContent struct {
content string
}
func (qwen3EventThinkingContent) isQwen3Event() {}
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwen3EventRawToolCall:
toolCall, err := parseQwen3ToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)
case qwen3EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3Parser) parseEvents() []qwen3Event {
var all []qwen3Event
keepLooping := true
for keepLooping {
var events []qwen3Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true
}
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
return splitAtTag(&p.buffer, tag, trimAfter)
}
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
var events []qwen3Event
switch p.state {
case qwen3ParserStateLookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingThinking
}
return events, true
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
} else if trimmed == "" {
return events, false
}
p.state = qwen3ParserStateCollectingContent
return events, true
case qwen3ParserStateThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
case qwen3ParserStateCollectingThinking:
acc := p.buffer.String()
// Some qwen3 checkpoints emit an explicit opening <think> tag even
// though the prompt already ended with <think>. Strip exactly one
// leading opening tag if present.
if p.maybeThinkingOpenAtBOL {
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
return events, false
}
p.maybeThinkingOpenAtBOL = false
return events, true
}
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
}
p.maybeThinkingOpenAtBOL = false
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwen3EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
case qwen3ParserStateThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
case qwen3ParserStateCollectingContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolOpenTag) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
case qwen3ParserStateToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
case qwen3ParserStateCollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolCloseTag) {
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("qwen3 tool call closing tag found but no content before it")
}
events = append(events, qwen3EventRawToolCall{raw: toolContent})
p.state = qwen3ParserStateCollectingContent
return events, true
}
return events, false
default:
panic("unreachable")
}
}
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var parsed struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
}
if parsed.Name == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for key, value := range parsed.Arguments {
toolCall.Function.Arguments.Set(key, value)
}
return toolCall, nil
}

147
model/parsers/qwen3_test.go Normal file
View File

@@ -0,0 +1,147 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen3ParserThinkingEnabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<thi", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "" || len(calls) != 0 {
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
}
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingDisabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, nil)
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCall(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Fatalf("expected location %q, got %v", "San Francisco", location)
}
unit, ok := calls[0].Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}

View File

@@ -115,6 +115,15 @@ func (s *store) saveLocked() error {
return err
}
// Read existing file into a generic map to preserve unknown fields
// (e.g. disable_ollama_cloud) that aliasStore doesn't own.
existing := make(map[string]json.RawMessage)
if data, err := os.ReadFile(s.path); err == nil {
if err := json.Unmarshal(data, &existing); err != nil {
slog.Debug("failed to parse existing server config; preserving unknown fields skipped", "path", s.path, "error", err)
}
}
// Combine exact and prefix entries
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
@@ -126,10 +135,17 @@ func (s *store) saveLocked() error {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := serverConfig{
Version: serverConfigVersion,
Aliases: entries,
// Overwrite only the keys we own
versionJSON, err := json.Marshal(serverConfigVersion)
if err != nil {
return err
}
aliasesJSON, err := json.Marshal(entries)
if err != nil {
return err
}
existing["version"] = versionJSON
existing["aliases"] = aliasesJSON
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
@@ -138,7 +154,7 @@ func (s *store) saveLocked() error {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
if err := enc.Encode(existing); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err

View File

@@ -38,6 +38,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
@@ -58,6 +59,11 @@ import (
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
const (
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
)
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
@@ -144,12 +150,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
delete(requestOpts, "use_imagegen_runner")
opts, err := s.modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
var runner *runnerRef
select {
case runner = <-runnerCh:
@@ -229,6 +238,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
return
}
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
@@ -1066,9 +1080,12 @@ func (s *Server) ShowHandler(c *gin.Context) {
resp, err := GetModelInfo(req)
if err != nil {
var statusErr api.StatusError
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case errors.As(err, &statusErr):
c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
@@ -1095,6 +1112,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return nil, err
}
if m.Config.RemoteHost != "" {
if disabled, _ := internalcloud.Status(); disabled {
return nil, api.StatusError{
StatusCode: http.StatusForbidden,
ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
}
}
}
modelDetails := api.ModelDetails{
ParentModel: m.ParentModel,
Format: m.Config.ModelFormat,
@@ -1571,6 +1597,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/status", s.StatusHandler)
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
@@ -1634,6 +1661,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
cloudDisabled, _ := internalcloud.Status()
slog.Info(fmt.Sprintf("Ollama cloud disabled: %t", cloudDisabled))
blobsDir, err := manifest.BlobsPath("")
if err != nil {
@@ -1824,6 +1853,16 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
func (s *Server) StatusHandler(c *gin.Context) {
disabled, source := internalcloud.Status()
c.JSON(http.StatusOK, api.StatusResponse{
Cloud: api.CloudStatus{
Disabled: disabled,
Source: source,
},
})
}
func (s *Server) WhoamiHandler(c *gin.Context) {
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
@@ -2010,6 +2049,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
return
}
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
@@ -16,7 +17,7 @@ import (
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
setTestHome(t, t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
@@ -40,7 +41,7 @@ func TestAliasShadowingRejected(t *testing.T) {
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
setTestHome(t, t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -256,7 +257,7 @@ func TestPrefixAliasChain(t *testing.T) {
func TestPrefixAliasCRUD(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
setTestHome(t, t.TempDir())
s := Server{}
@@ -364,7 +365,7 @@ func TestPrefixAliasCaseInsensitive(t *testing.T) {
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
setTestHome(t, t.TempDir())
s := Server{}
@@ -424,3 +425,51 @@ func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestAliasSavePreservesCloudDisable(t *testing.T) {
gin.SetMode(gin.TestMode)
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".ollama", "server.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
t.Fatal(err)
}
initial := map[string]any{
"version": serverConfigVersion,
"disable_ollama_cloud": true,
"aliases": []aliasEntry{},
}
data, err := json.Marshal(initial)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, data, 0o644); err != nil {
t.Fatal(err)
}
s := Server{}
w := createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
updated, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var updatedCfg map[string]json.RawMessage
if err := json.Unmarshal(updated, &updatedCfg); err != nil {
t.Fatal(err)
}
raw, ok := updatedCfg["disable_ollama_cloud"]
if !ok {
t.Fatal("expected disable_ollama_cloud key to be preserved")
}
if string(raw) != "true" {
t.Fatalf("expected disable_ollama_cloud to remain true, got %s", string(raw))
}
}

View File

@@ -0,0 +1,94 @@
package server
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
func TestStatusHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
t.Setenv("OLLAMA_NO_CLOUD", "1")
s := Server{}
w := createRequest(t, s.StatusHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.StatusResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if !resp.Cloud.Disabled {
t.Fatalf("expected cloud.disabled true, got false")
}
if resp.Cloud.Source != "env" {
t.Fatalf("expected cloud.source env, got %q", resp.Cloud.Source)
}
}
func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
t.Setenv("OLLAMA_NO_CLOUD", "1")
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("chat remote blocked", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-cloud",
Messages: []api.Message{{Role: "user", Content: "hi"}},
})
if w.Code != http.StatusForbidden {
t.Fatalf("expected status 403, got %d", w.Code)
}
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
t.Fatalf("unexpected response: %s", got)
}
})
t.Run("generate remote blocked", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-cloud",
Prompt: "hi",
})
if w.Code != http.StatusForbidden {
t.Fatalf("expected status 403, got %d", w.Code)
}
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
t.Fatalf("unexpected response: %s", got)
}
})
t.Run("show remote blocked", func(t *testing.T) {
w := createRequest(t, s.ShowHandler, api.ShowRequest{
Model: "test-cloud",
})
if w.Code != http.StatusForbidden {
t.Fatalf("expected status 403, got %d", w.Code)
}
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable)+`"}` {
t.Fatalf("unexpected response: %s", got)
}
})
}

View File

@@ -2371,29 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2409,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Fatal(err)
}
loadedModel, err := GetModel("test-image")
if err != nil {
t.Fatal(err)
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
schedulerModelKey(loadedModel): {
llama: &mock,
Options: &opts,
model: loadedModel,
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",

View File

@@ -22,6 +22,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -32,6 +33,7 @@ type LlmRequest struct {
successCh chan *runnerRef
errCh chan error
schedAttempts uint
useImagegen bool
}
type Scheduler struct {
@@ -81,8 +83,30 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
// schedulerModelKey returns the scheduler map key for a model.
// GGUF-backed models use ModelPath; safetensors/image models without a
// ModelPath use manifest digest so distinct models don't collide.
func schedulerModelKey(m *Model) string {
if m == nil {
return ""
}
if m.ModelPath != "" {
return m.ModelPath
}
if m.Digest != "" {
return "digest:" + m.Digest
}
if m.Name != "" {
return "name:" + m.Name
}
if m.ShortName != "" {
return "short:" + m.ShortName
}
return ""
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -99,10 +123,12 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -148,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -163,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -289,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
@@ -344,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -373,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -511,6 +539,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -525,7 +554,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -533,7 +562,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -566,17 +595,20 @@ iGPUScan:
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = imagegen.ModeImageGen
} else {
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName, mode)
var server llm.LlamaServer
var err error
isImagegen := false
if slices.Contains(req.model.Config.Capabilities, "image") {
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
isImagegen = true
} else if req.useImagegen {
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
isImagegen = true
} else {
server, err = mlxrunner.NewClient(modelName)
}
if err != nil {
req.errCh <- err
return true
@@ -590,16 +622,18 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
isImagegen: isImagegen,
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
@@ -667,6 +701,7 @@ type runnerRef struct {
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
vramSize uint64
totalSize uint64
@@ -676,6 +711,7 @@ type runnerRef struct {
model *Model
modelPath string
modelKey string
numParallel int
*api.Options
}
@@ -695,10 +731,16 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Check if runner type (imagegen vs mlxrunner) matches what's requested
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
if runner.isImagegen != wantImagegen {
return true
}
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
@@ -800,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -814,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -859,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
// Secondary sort by model key/path lex order
n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -920,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath]
runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()

View File

@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
// Starts in pending channel, then should be quickly processed to return an error
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
require.Empty(t, successCh)
require.Empty(t, errCh)
require.Len(t, s.pendingReqCh, 1)
}
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
cancelReq()
select {
case runner := <-successCh:
require.Equal(t, loadedRunner, runner)
default:
t.Fatal("expected existing runner to be reused")
}
require.Empty(t, errCh)
require.Empty(t, s.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
@@ -509,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {

14
server/test_home_test.go Normal file
View File

@@ -0,0 +1,14 @@
package server
import (
"testing"
"github.com/ollama/ollama/envconfig"
)
func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
envconfig.ReloadServerConfig()
}

190
version/update.go Normal file
View File

@@ -0,0 +1,190 @@
package version
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
)
var updateCheckURLBase = "https://ollama.com"
// CheckForUpdate calls the ollama.com update API and reports whether a
// newer version is available.
func CheckForUpdate(ctx context.Context) (bool, error) {
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
if err != nil {
return false, fmt.Errorf("parse update URL: %w", err)
}
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", Version)
requestURL.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return false, fmt.Errorf("create request: %w", err)
}
_ = auth.SignRequest(ctx, req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("update check request: %w", err)
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK, nil
}
func cacheFilePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "update"), nil
}
// CacheAvailableUpdate creates the update marker file.
func CacheAvailableUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
return f.Close()
}
// HasCachedUpdate reports whether a non-stale update marker exists.
func HasCachedUpdate() bool {
path, err := cacheFilePath()
if err != nil {
return false
}
fi, err := os.Stat(path)
if err != nil {
return false
}
return time.Since(fi.ModTime()) <= 24*time.Hour
}
// ClearCachedUpdate removes the update marker file.
func ClearCachedUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
err = os.Remove(path)
if os.IsNotExist(err) {
return nil
}
return err
}
func IsOfficialInstall() bool {
exe, err := os.Executable()
if err != nil {
return false
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return false
}
switch runtime.GOOS {
case "windows":
localAppData := os.Getenv("LOCALAPPDATA")
if localAppData == "" {
return false
}
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
case "darwin":
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
default:
dir := filepath.Dir(exe)
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
}
}
// DoUpdate downloads and runs the platform-appropriate install script.
func DoUpdate(force bool) error {
if !force && !IsOfficialInstall() {
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
}
var scriptURL, tmpPattern, shell string
switch runtime.GOOS {
case "windows":
scriptURL = "https://ollama.com/install.ps1"
tmpPattern = "ollama-install-*.ps1"
shell = "powershell"
default:
scriptURL = "https://ollama.com/install.sh"
tmpPattern = "ollama-install-*.sh"
shell = "sh"
}
resp, err := http.Get(scriptURL)
if err != nil {
return fmt.Errorf("download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download install script: status %d", resp.StatusCode)
}
tmpFile, err := os.CreateTemp("", tmpPattern)
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
tmpFile.Close()
return fmt.Errorf("write install script: %w", err)
}
tmpFile.Close()
cmd := exec.Command(shell, tmpFile.Name())
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// IsLocalHost reports whether the configured Ollama host points to the
// local machine.
func IsLocalHost(host *url.URL) bool {
hostname := host.Hostname()
switch hostname {
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
return true
}
if ip := net.ParseIP(hostname); ip != nil {
return ip.IsLoopback()
}
return false
}

146
version/update_test.go Normal file
View File

@@ -0,0 +1,146 @@
package version
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"testing"
"time"
)
func setHome(t *testing.T, dir string) {
t.Helper()
if runtime.GOOS == "windows" {
t.Setenv("USERPROFILE", dir)
} else {
t.Setenv("HOME", dir)
}
}
func TestCheckForUpdate(t *testing.T) {
t.Run("update available", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
t.Error("missing expected query parameters")
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !available {
t.Fatal("expected update to be available")
}
})
t.Run("up to date", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if available {
t.Fatal("expected no update available")
}
})
t.Run("network error", func(t *testing.T) {
old := updateCheckURLBase
updateCheckURLBase = "http://localhost:1"
defer func() { updateCheckURLBase = old }()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := CheckForUpdate(ctx)
if err == nil {
t.Fatal("expected error for unreachable server")
}
})
}
func TestCacheRoundTrip(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
if !HasCachedUpdate() {
t.Fatal("expected cached update to be present")
}
if err := ClearCachedUpdate(); err != nil {
t.Fatalf("cache clear: %v", err)
}
if HasCachedUpdate() {
t.Fatal("expected no cached update after clear")
}
}
func TestHasCachedUpdateStale(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
// Backdate the file to make it stale
path := filepath.Join(tmp, ".ollama", "update")
staleTime := time.Now().Add(-25 * time.Hour)
os.Chtimes(path, staleTime, staleTime)
if HasCachedUpdate() {
t.Fatal("expected no cached update for stale file")
}
}
func TestIsLocalHost(t *testing.T) {
tests := []struct {
host string
local bool
}{
{"http://127.0.0.1:11434", true},
{"http://localhost:11434", true},
{"http://[::1]:11434", true},
{"http://0.0.0.0:11434", true},
{"http://remote.example.com:11434", false},
{"http://192.168.1.100:11434", false},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
u, err := url.Parse(tt.host)
if err != nil {
t.Fatalf("parse URL: %v", err)
}
if got := IsLocalHost(u); got != tt.local {
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/signal"
@@ -18,6 +19,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/model"
@@ -62,6 +64,18 @@ func isLocalServer() bool {
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, false
}
return false, false
}
return status.Cloud.Disabled, true
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
@@ -86,6 +100,10 @@ func waitForOllamaSignin(ctx context.Context) error {
return err
}
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("cloud account endpoints are unavailable"))
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
@@ -664,6 +682,15 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
supportsTools = false
}
if enableWebsearch {
if client, err := api.ClientFromEnvironment(); err == nil {
if disabled, known := cloudStatusDisabled(cmd.Context(), client); known && disabled {
fmt.Fprintf(os.Stderr, "%s\n", internalcloud.DisabledError("web search is unavailable"))
enableWebsearch = false
}
}
}
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {

View File

@@ -30,6 +30,8 @@ type ModelfileConfig struct {
Template string
System string
License string
Parser string
Renderer string
}
// CreateOptions holds all options for model creation.
@@ -37,7 +39,7 @@ type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
}
// CreateModel imports a model from a local directory.
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: resolveRendererName(opts.Modelfile, rendererName),
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
}
}
func resolveParserName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Parser != "" {
return mf.Parser
}
return inferred
}
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Renderer != "" {
return mf.Renderer
}
return inferred
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}

View File

@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
Parser: "qwen3",
Renderer: "qwen3",
}
if config.Template != "{{ .Prompt }}" {
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
if config.Parser != "qwen3" {
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
}
if config.Renderer != "qwen3" {
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
if config.Parser != "" {
t.Errorf("Parser should be empty, got %q", config.Parser)
}
if config.Renderer != "" {
t.Errorf("Renderer should be empty, got %q", config.Renderer)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
if config.License != "" {
t.Error("License should be empty")
}
if config.Parser != "" {
t.Error("Parser should be empty")
}
if config.Renderer != "" {
t.Error("Renderer should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
Template: "test",
System: "system",
License: "MIT",
Parser: "qwen3-thinking",
Renderer: "qwen3",
},
}
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
if opts.Modelfile.Parser != "qwen3-thinking" {
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
}
if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
}
}
func TestResolveParserName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3",
want: "qwen3",
},
{
name: "empty parser uses inferred",
mf: &ModelfileConfig{
Parser: "",
},
inferred: "qwen3",
want: "qwen3",
},
{
name: "explicit parser overrides inferred",
mf: &ModelfileConfig{
Parser: "qwen3-thinking",
},
inferred: "qwen3",
want: "qwen3-thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveRendererName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "empty renderer uses inferred",
mf: &ModelfileConfig{
Renderer: "",
},
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "explicit renderer overrides inferred",
mf: &ModelfileConfig{
Renderer: "qwen3",
},
inferred: "qwen3-coder",
want: "qwen3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
}
})
}
}
func TestCreateOptions_Defaults(t *testing.T) {

View File

@@ -102,15 +102,20 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
for _, entry := range entries {
name := entry.name
// Try to get tensor by stripped name first, then with component prefix.
// Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
// while the tensors map uses stripped names (e.g., "model.layers.0.weight").
// Try to get tensor by stripped name first, then with component prefix,
// then fall back to "data" for legacy blobs created by older versions
// that stored all tensors with the generic key "data".
lookupName := name
arr := sf.Get(lookupName)
if arr == nil && mw.component != "" {
lookupName = mw.component + "/" + name
arr = sf.Get(lookupName)
}
if arr == nil {
// Legacy blob format: tensor stored as "data"
lookupName = "data"
arr = sf.Get(lookupName)
}
if arr != nil {
// Single-tensor blob or tensor found by name
if dtype != 0 && arr.Dtype() != dtype {

View File

@@ -16,10 +16,10 @@ import (
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
if err != nil {
return err
}
// Private headers contain C++ implementation helpers and are not part of
// the C API surface; parsing them can produce invalid wrapper signatures.
if d.IsDir() && d.Name() == "private" {
return fs.SkipDir
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}

View File

@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
#endif
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
#endif
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
return mlx_array_new_data_ptr(data, shape, dim, dtype);
}
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
}
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
}
int mlx_array_set(mlx_array* arr, const mlx_array src) {
return mlx_array_set_ptr(arr, src);
}
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
return mlx_set_compile_mode_ptr(mode);
}
int mlx_cuda_is_available(bool* res) {
return mlx_cuda_is_available_ptr(res);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
return mlx_set_default_device_ptr(dev);
}
int mlx_device_is_available(bool* avail, mlx_device dev) {
return mlx_device_is_available_ptr(avail, dev);
}
int mlx_device_count(int* count, mlx_device_type type) {
return mlx_device_count_ptr(count, type);
}
mlx_device_info mlx_device_info_new(void) {
return mlx_device_info_new_ptr();
}
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
return mlx_device_info_get_ptr(info, dev);
}
int mlx_device_info_free(mlx_device_info info) {
return mlx_device_info_free_ptr(info);
}
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
return mlx_device_info_has_key_ptr(exists, info, key);
}
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
return mlx_device_info_is_string_ptr(is_string, info, key);
}
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
return mlx_device_info_get_string_ptr(value, info, key);
}
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
return mlx_device_info_get_size_ptr(value, info, key);
}
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
return mlx_device_info_get_keys_ptr(keys, info);
}
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
return mlx_distributed_all_gather_ptr(res, x, group, S);
}
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
return mlx_set_wired_limit_ptr(res, limit);
}
mlx_metal_device_info_t mlx_metal_device_info(void) {
return mlx_metal_device_info_ptr();
}
int mlx_metal_is_available(bool* res) {
return mlx_metal_is_available_ptr(res);
}

View File

@@ -26,6 +26,8 @@
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_new_data_managed
#undef mlx_array_new_data_managed_payload
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
@@ -121,6 +123,7 @@
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_cuda_is_available
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
@@ -131,6 +134,16 @@
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_device_is_available
#undef mlx_device_count
#undef mlx_device_info_new
#undef mlx_device_info_get
#undef mlx_device_info_free
#undef mlx_device_info_has_key
#undef mlx_device_info_is_string
#undef mlx_device_info_get_string
#undef mlx_device_info_get_size
#undef mlx_device_info_get_keys
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
@@ -261,7 +274,6 @@
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
extern mlx_array (*mlx_array_new_double_ptr)(double val);
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
#endif
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
#endif
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
extern int (*mlx_disable_compile_ptr)(void);
extern int (*mlx_enable_compile_ptr)(void);
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
extern int (*mlx_cuda_is_available_ptr)(bool* res);
extern mlx_device (*mlx_device_new_ptr)(void);
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
extern int (*mlx_device_free_ptr)(mlx_device dev);
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
extern int (*mlx_metal_is_available_ptr)(bool* res);
extern int (*mlx_metal_start_capture_ptr)(const char* path);
extern int (*mlx_metal_stop_capture_ptr)(void);
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
int mlx_array_set(mlx_array* arr, const mlx_array src);
int mlx_array_set_bool(mlx_array* arr, bool val);
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
int mlx_array_item_float64(double* res, const mlx_array arr);
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
const double* mlx_array_data_float64(const mlx_array arr);
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr);
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
int mlx_cuda_is_available(bool* res);
mlx_device mlx_device_new(void);
mlx_device mlx_device_new_type(mlx_device_type type, int index);
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
int mlx_set_default_device(mlx_device dev);
int mlx_device_is_available(bool* avail, mlx_device dev);
int mlx_device_count(int* count, mlx_device_type type);
mlx_device_info mlx_device_info_new(void);
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
int mlx_device_info_free(mlx_device_info info);
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
mlx_metal_device_info_t mlx_metal_device_info(void);
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);

View File

@@ -2,76 +2,298 @@ package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
Port int
*exec.Cmd
port int
modelName string
vramSize uint64
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
}
func (c *Client) JoinPath(path string) string {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
}).JoinPath(path).String()
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
func NewClient(modelName string) (*Client, error) {
if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the current executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{
port: port,
modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd,
}
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("mlx-runner", "msg", line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
c.done <- err
}()
// Wait for subprocess to be ready
if err := c.waitUntilRunning(); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func (c *Client) CheckError(w *http.Response) error {
if w.StatusCode >= 400 {
return errors.New(w.Status)
func (c *Client) getLastErr() string {
c.lastErrLock.Lock()
defer c.lastErrLock.Unlock()
return c.lastErr
}
func (c *Client) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-c.done:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := c.Ping(ctx); err == nil {
slog.Info("mlx runner is ready", "port", c.port)
return nil
}
}
}
}
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
Prompt string `json:"prompt"`
Options *completionOpts `json:"options,omitempty"`
}
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
c.cmd.Process.Signal(os.Interrupt)
select {
case <-c.done:
case <-time.After(5 * time.Second):
c.cmd.Process.Kill()
}
c.cmd = nil
}
return nil
}
// Close implements llm.LlamaServer.
func (c *Client) Close() error {
return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(req); err != nil {
return err
creq := completionRequest{
Prompt: req.Prompt,
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
NumPredict: req.Options.NumPredict,
}
}
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
body, err := json.Marshal(creq)
if err != nil {
return err
}
defer w.Body.Close()
if err := c.CheckError(w); err != nil {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
scanner := bufio.NewScanner(w.Body)
for scanner.Scan() {
bts := scanner.Bytes()
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
var resp llm.CompletionResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
fn(resp)
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
}
return nil
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
if cresp.Done {
return nil
}
}
return scanner.Err()
}
func (c *Client) ContextLength() int {
@@ -80,71 +302,89 @@ func (c *Client) ContextLength() int {
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
panic("unimplemented")
return "", errors.New("not supported")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
panic("unimplemented")
return nil, 0, errors.New("not supported")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
panic("unimplemented")
return nil
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.Port
return c.port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
panic("unimplemented")
select {
case <-c.done:
return true
default:
return false
}
}
// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
if err != nil {
return nil, err
}
defer w.Body.Close()
return []ml.DeviceID{}, nil
return nil, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
panic("unimplemented")
return c.modelName
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
panic("unimplemented")
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
return c.cmd.Process.Pid
}
return -1
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
w, err := http.Get(c.JoinPath("/v1/status"))
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return err
}
defer w.Body.Close()
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
if err != nil {
return nil, err
}
defer w.Body.Close()
req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokens []int
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
return nil, err
}
@@ -153,22 +393,22 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
panic("unimplemented")
return c.vramSize
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
panic("unimplemented")
return c.vramSize
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
panic("unimplemented")
return c.vramSize
}
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
panic("unimplemented")
return nil
}
var _ llm.LlamaServer = (*Client)(nil)

10
x/mlxrunner/imports.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build mlx
package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
)

View File

@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
FetchContent_Declare(
mlx-c

View File

@@ -133,6 +133,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
@@ -248,9 +249,9 @@ func Free(s ...*Array) (n int) {
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)

View File

@@ -24,6 +24,37 @@ func CheckInit() error {
return initError
}
// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
// Returns true if the library was successfully loaded.
func tryLoadFromDir(dir string) bool {
matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
if err != nil || len(matches) == 0 {
return false
}
for _, match := range matches {
path := filepath.Join(dir, match)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
return true
}
return false
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -33,44 +64,34 @@ func init() {
return
}
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
return
// Try OLLAMA_LIBRARY_PATH first
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
for _, dir := range filepath.SplitList(paths) {
if tryLoadFromDir(dir) {
return
}
}
}
for _, path := range filepath.SplitList(paths) {
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
if err != nil {
initError = fmt.Errorf("failed to glob for MLX libraries in %s: %w", path, err)
slog.Warn("MLX dynamic library not available", "error", initError)
return
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
searchDirs = append(searchDirs, filepath.Dir(exe))
}
for _, match := range matches {
path := filepath.Join(paths, match)
slog.Info("Loading MLX dynamic library", "path", path)
if cwd, err := os.Getwd(); err == nil {
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
}
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
slog.Info("Loaded MLX dynamic library", "path", path)
for _, dir := range searchDirs {
if tryLoadFromDir(dir) {
return
}
}
initError = fmt.Errorf("failed to load any MLX dynamic library from OLLAMA_LIBRARY_PATH=%s", paths)
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
slog.Warn("MLX dynamic library not available", "error", initError)
}

View File

@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
const int* shape,
int dim,
mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*)) = NULL;
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
const mlx_vector_array input) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
const mlx_vector_array input_2) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
size_t input_2_num) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_)(void) = NULL;
int (*mlx_enable_compile_)(void) = NULL;
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_)(bool* res) = NULL;
mlx_device (*mlx_device_new_)(void) = NULL;
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_)(mlx_device dev) = NULL;
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_)(
bool* exists,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_is_string_)(
bool* is_string,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_string_)(
const char** value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_size_)(
size_t* value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
int (*mlx_metal_is_available_)(bool* res) = NULL;
int (*mlx_metal_start_capture_)(const char* path) = NULL;
int (*mlx_metal_stop_capture_)(void) = NULL;
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
const mlx_array values,
int axis,
const mlx_stream s) = NULL;
int (*mlx_qqmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array w_scales /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
const char * (*mlx_string_data_)(mlx_string str) = NULL;
int (*mlx_string_free_)(mlx_string str) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_)(
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_new_double);
CHECK_LOAD(handle, mlx_array_new_complex);
CHECK_LOAD(handle, mlx_array_new_data);
CHECK_LOAD(handle, mlx_array_new_data_managed);
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
CHECK_LOAD(handle, mlx_array_set);
CHECK_LOAD(handle, mlx_array_set_bool);
CHECK_LOAD(handle, mlx_array_set_int);
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_disable_compile);
CHECK_LOAD(handle, mlx_enable_compile);
CHECK_LOAD(handle, mlx_set_compile_mode);
CHECK_LOAD(handle, mlx_cuda_is_available);
CHECK_LOAD(handle, mlx_device_new);
CHECK_LOAD(handle, mlx_device_new_type);
CHECK_LOAD(handle, mlx_device_free);
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_device_get_type);
CHECK_LOAD(handle, mlx_get_default_device);
CHECK_LOAD(handle, mlx_set_default_device);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_device_is_available);
CHECK_LOAD(handle, mlx_device_count);
CHECK_LOAD(handle, mlx_device_info_new);
CHECK_LOAD(handle, mlx_device_info_get);
CHECK_LOAD(handle, mlx_device_info_free);
CHECK_LOAD(handle, mlx_device_info_has_key);
CHECK_LOAD(handle, mlx_device_info_is_string);
CHECK_LOAD(handle, mlx_device_info_get_string);
CHECK_LOAD(handle, mlx_device_info_get_size);
CHECK_LOAD(handle, mlx_device_info_get_keys);
CHECK_LOAD(handle, mlx_distributed_all_gather);
CHECK_LOAD(handle, mlx_distributed_all_max);
CHECK_LOAD(handle, mlx_distributed_all_min);
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_distributed_recv_like);
CHECK_LOAD(handle, mlx_distributed_send);
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_set_error_handler);
CHECK_LOAD(handle, _mlx_error);
CHECK_LOAD(handle, mlx_export_function);
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
CHECK_LOAD(handle, mlx_fast_rms_norm);
CHECK_LOAD(handle, mlx_fast_rope);
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
CHECK_LOAD(handle, mlx_fft_fft);
CHECK_LOAD(handle, mlx_fft_fft2);
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fft_rfft);
CHECK_LOAD(handle, mlx_fft_rfft2);
CHECK_LOAD(handle, mlx_fft_rfftn);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
CHECK_LOAD(handle, mlx_io_reader_free);
CHECK_LOAD(handle, mlx_io_writer_new);
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
CHECK_LOAD(handle, mlx_io_reader_free);
CHECK_LOAD(handle, mlx_io_writer_new);
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_linalg_cholesky);
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
CHECK_LOAD(handle, mlx_linalg_cross);
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_set_cache_limit);
CHECK_LOAD(handle, mlx_set_memory_limit);
CHECK_LOAD(handle, mlx_set_wired_limit);
CHECK_LOAD(handle, mlx_metal_device_info);
CHECK_LOAD(handle, mlx_metal_is_available);
CHECK_LOAD(handle, mlx_metal_start_capture);
CHECK_LOAD(handle, mlx_metal_stop_capture);
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_full);
CHECK_LOAD(handle, mlx_full_like);
CHECK_LOAD(handle, mlx_gather);
CHECK_LOAD(handle, mlx_gather_single);
CHECK_LOAD(handle, mlx_gather_mm);
CHECK_LOAD(handle, mlx_gather_qmm);
CHECK_LOAD(handle, mlx_greater);
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_prod_axis);
CHECK_LOAD(handle, mlx_prod);
CHECK_LOAD(handle, mlx_put_along_axis);
CHECK_LOAD(handle, mlx_qqmm);
CHECK_LOAD(handle, mlx_quantize);
CHECK_LOAD(handle, mlx_quantized_matmul);
CHECK_LOAD(handle, mlx_radians);
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_round);
CHECK_LOAD(handle, mlx_rsqrt);
CHECK_LOAD(handle, mlx_scatter);
CHECK_LOAD(handle, mlx_scatter_single);
CHECK_LOAD(handle, mlx_scatter_add);
CHECK_LOAD(handle, mlx_scatter_add_single);
CHECK_LOAD(handle, mlx_scatter_add_axis);
CHECK_LOAD(handle, mlx_scatter_max);
CHECK_LOAD(handle, mlx_scatter_max_single);
CHECK_LOAD(handle, mlx_scatter_min);
CHECK_LOAD(handle, mlx_scatter_min_single);
CHECK_LOAD(handle, mlx_scatter_prod);
CHECK_LOAD(handle, mlx_scatter_prod_single);
CHECK_LOAD(handle, mlx_segmented_mm);
CHECK_LOAD(handle, mlx_sigmoid);
CHECK_LOAD(handle, mlx_sign);
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_string_set);
CHECK_LOAD(handle, mlx_string_data);
CHECK_LOAD(handle, mlx_string_free);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_async_eval);
CHECK_LOAD(handle, mlx_checkpoint);
CHECK_LOAD(handle, mlx_custom_function);
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_jvp);
CHECK_LOAD(handle, mlx_value_and_grad);
CHECK_LOAD(handle, mlx_vjp);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_vector_array_new);
CHECK_LOAD(handle, mlx_vector_array_set);
CHECK_LOAD(handle, mlx_vector_array_free);

View File

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,10 @@
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}

View File

@@ -306,19 +306,42 @@ func AddMM(c, a, b *Array, alpha, beta float32) *Array {
// Scalar helpers
// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.DType()
if dtype == DTypeFloat32 {
return f32
}
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_array_free(f32)
return casted
}
func AddScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Add(scalar)
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR", a)
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func MulScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Multiply(scalar)
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR", a)
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func DivScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Divide(scalar)
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR", a)
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func FloorDivideScalar(a *Array, s int32) *Array {

View File

@@ -0,0 +1,85 @@
//go:build mlx
package base
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
)
// Model is the interface that model implementations must satisfy.
type Model interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert
// stacking, quantized layer creation) happens here.
LoadWeights(tensors map[string]*mlx.Array) error
}
var (
mu sync.Mutex
registry = make(map[string]func(root *model.Root) (Model, error))
)
// Register registers a model constructor by architecture name.
// Called from init() in model packages. Panics on duplicate registration.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
mu.Lock()
defer mu.Unlock()
if _, exists := registry[arch]; exists {
panic(fmt.Sprintf("model architecture %q already registered", arch))
}
registry[arch] = fn
}
// New reads config.json from the manifest, detects the architecture, looks up
// the registered constructor, and calls it to create the model (with config
// parsed and struct created, but weights not yet loaded).
func New(root *model.Root) (Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return nil, fmt.Errorf("failed to parse config.json: %w", err)
}
if len(archConfig.Architectures) == 0 {
return nil, fmt.Errorf("no architectures found in config.json")
}
arch := archConfig.Architectures[0]
slog.Info("Model architecture", "arch", arch)
mu.Lock()
fn, ok := registry[arch]
mu.Unlock()
if !ok {
return nil, fmt.Errorf("unsupported architecture: %s", arch)
}
return fn(root)
}
// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
func Weights(m Model) func(map[string]*mlx.Array) error {
return m.LoadWeights
}

View File

@@ -0,0 +1,3 @@
//go:build !mlx
package base

View File

@@ -0,0 +1,92 @@
//go:build mlx
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
type LinearFactory struct {
tensors map[string]*mlx.Array
defaultGroupSize int
defaultBits int
defaultMode string
tensorQuant map[string]*TensorQuantInfo
}
// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
tensors map[string]*mlx.Array,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
return LinearFactory{
tensors: tensors,
defaultGroupSize: defaultGroupSize,
defaultBits: defaultBits,
defaultMode: defaultMode,
tensorQuant: tensorQuant,
}
}
// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
return MakeLinearLayer(
f.tensors,
path,
f.defaultGroupSize,
f.defaultBits,
f.defaultMode,
f.tensorQuant,
)
}
// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}

130
x/mlxrunner/model/quant.go Normal file
View File

@@ -0,0 +1,130 @@
//go:build mlx
package model
import (
"strings"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
// when available, otherwise falling back to the provided model defaults.
func TensorQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
) (groupSize, bits int, mode string, fromTensor bool) {
if tensorQuant != nil {
if tq := tensorQuant[tensorName]; tq != nil {
groupSize, bits, mode = QuantizationParams(tq.QuantType)
if tq.GroupSize > 0 {
groupSize = tq.GroupSize
}
return groupSize, bits, mode, true
}
}
return defaultGroupSize, defaultBits, defaultMode, false
}
// ResolveLinearQuantParams resolves quantization params for a quantized linear
// tensor, preferring per-tensor metadata and falling back to shape-based
// inference for affine packed tensors.
func ResolveLinearQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
weight, scales *mlx.Array,
) (groupSize, bits int, mode string) {
groupSize, bits, mode, fromTensor := TensorQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
tensorName,
)
if mode == "affine" {
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
if !fromTensor || groupSize == 0 || bits == 0 {
groupSize = inferredGroupSize
bits = inferredBits
}
}
}
return groupSize, bits, mode
}
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
// tensors from packed weight and scale shapes.
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
if weight == nil || scales == nil {
return 0, 0, false
}
weightShape := weight.Dims()
scaleShape := scales.Dims()
if len(weightShape) == 0 || len(scaleShape) == 0 {
return 0, 0, false
}
weightCols := weightShape[len(weightShape)-1]
scalesCols := scaleShape[len(scaleShape)-1]
if weightCols <= 0 || scalesCols <= 0 {
return 0, 0, false
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return 32, 4, true
case groupSize8 == 64:
return 64, 8, true
case groupSize4 == 64 && groupSize8 == 32:
if hintBits == 8 {
return 32, 8, true
}
if hintBits == 4 {
return 64, 4, true
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return groupSize4, 4, true
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return groupSize8, 8, true
}
return 0, 0, false
}
func isCommonGroupSize(v int) bool {
switch v {
case 16, 32, 64, 128:
return true
default:
return false
}
}

Some files were not shown because too many files have changed in this diff Show More