mirror of
https://github.com/ollama/ollama.git
synced 2026-02-19 07:45:22 -05:00
Compare commits
18 Commits
v0.16.1
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
365a3657ad | ||
|
|
71c1d8d0a9 | ||
|
|
325b72bc31 | ||
|
|
f01a9a7859 | ||
|
|
9aefd2dfee | ||
|
|
d07e4a1dd3 | ||
|
|
8a257ec00a | ||
|
|
2f4de1acf7 | ||
|
|
ec95c45f70 | ||
|
|
3a88f7eb20 | ||
|
|
0d5da826d4 | ||
|
|
9b795698b8 | ||
|
|
041fb77639 | ||
|
|
8224cce583 | ||
|
|
d18dcd7775 | ||
|
|
5f5ef20131 | ||
|
|
f0a07a353b | ||
|
|
948de6bbd2 |
@@ -1 +1 @@
|
||||
v0.4.1
|
||||
v0.5.0
|
||||
|
||||
@@ -16,7 +16,7 @@ Start building with open models.
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
```
|
||||
|
||||
or [download manually](http://localhost:8080/download/Ollama.dmg)
|
||||
or [download manually](https://ollama.com/download/Ollama.dmg)
|
||||
|
||||
### Windows
|
||||
|
||||
|
||||
@@ -1,17 +1,25 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
@@ -82,22 +90,25 @@ type MessageParam struct {
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For text blocks with citations
|
||||
Citations []Citation `json:"citations,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
// For tool_use and server_tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
// For tool_result and web_search_tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
@@ -105,6 +116,30 @@ type ContentBlock struct {
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// Citation represents a citation in a text block
|
||||
type Citation struct {
|
||||
Type string `json:"type"` // "web_search_result_location"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedIndex string `json:"encrypted_index,omitempty"`
|
||||
CitedText string `json:"cited_text,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchResult represents a single web search result
|
||||
type WebSearchResult struct {
|
||||
Type string `json:"type"` // "web_search_result"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
PageAge string `json:"page_age,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchToolResultError represents an error from web search
|
||||
type WebSearchToolResultError struct {
|
||||
Type string `json:"type"` // "web_search_tool_result_error"
|
||||
ErrorCode string `json:"error_code"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
@@ -115,10 +150,13 @@ type ImageSource struct {
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
|
||||
// Web search specific fields
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
@@ -233,6 +271,8 @@ type StreamErrorEvent struct {
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))
|
||||
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
@@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
for i, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
@@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
hasBuiltinWebSearch := false
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
hasBuiltinWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range r.Tools {
|
||||
// Anthropic built-in web_search maps to Ollama function name "web_search".
|
||||
// If a user-defined tool also uses that name in the same request, drop the
|
||||
// user-defined one to avoid ambiguous tool-call routing.
|
||||
if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
|
||||
logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
|
||||
continue
|
||||
}
|
||||
|
||||
tool, _, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
convertedRequest := &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))
|
||||
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
@@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
@@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
@@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
@@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
@@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
id, _ := blockMap["id"].(string)
|
||||
name, _ := blockMap["name"].(string)
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
@@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
func formatWebSearchToolResultContent(content any) string {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return c
|
||||
case []WebSearchResult:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
if item.Type != "web_search_result" {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
|
||||
}
|
||||
return resultContent.String()
|
||||
case []any:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch itemMap["type"] {
|
||||
case "web_search_result":
|
||||
title, _ := itemMap["title"].(string)
|
||||
url, _ := itemMap["url"].(string)
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
|
||||
case "web_search_tool_result_error":
|
||||
errorCode, _ := itemMap["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
}
|
||||
return resultContent.String()
|
||||
case map[string]any:
|
||||
if c["type"] == "web_search_tool_result_error" {
|
||||
errorCode, _ := c["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
case WebSearchToolResultError:
|
||||
if c.ErrorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + c.ErrorCode
|
||||
default:
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
|
||||
func convertTool(t Tool) (api.Tool, bool, error) {
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The search query to look up on the web",
|
||||
})
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"query"},
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
|
||||
return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}, false, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
@@ -899,3 +1098,113 @@ func countContentBlock(block any) int {
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// OllamaWebSearchRequest represents a request to the Ollama web search API
|
||||
type OllamaWebSearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResult represents a single search result from Ollama API
|
||||
type OllamaWebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResponse represents the response from the Ollama web search API
|
||||
type OllamaWebSearchResponse struct {
|
||||
Results []OllamaWebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
var WebSearchEndpoint = "https://ollama.com/api/web_search"
|
||||
|
||||
func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
|
||||
if internalcloud.Disabled() {
|
||||
logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
|
||||
return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
|
||||
}
|
||||
|
||||
if maxResults <= 0 {
|
||||
maxResults = 5
|
||||
}
|
||||
if maxResults > 10 {
|
||||
maxResults = 10
|
||||
}
|
||||
|
||||
reqBody := OllamaWebSearchRequest{
|
||||
Query: query,
|
||||
MaxResults: maxResults,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal web search request: %w", err)
|
||||
}
|
||||
|
||||
searchURL, err := url.Parse(WebSearchEndpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web search URL: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search request",
|
||||
"query", TraceTruncateString(query),
|
||||
"max_results", maxResults,
|
||||
"url", searchURL.String(),
|
||||
)
|
||||
|
||||
q := searchURL.Query()
|
||||
q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
searchURL.RawQuery = q.Encode()
|
||||
|
||||
signature := ""
|
||||
if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
|
||||
challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
|
||||
signature, err = auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign web search request: %w", err)
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create web search request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if signature != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var searchResp OllamaWebSearchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode web search response: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))
|
||||
|
||||
return &searchResp, nil
|
||||
}
|
||||
|
||||
func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
|
||||
var results []WebSearchResult
|
||||
for _, r := range ollamaResults.Results {
|
||||
results = append(results, WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search that should be dropped",
|
||||
InputSchema: json.RawMessage(`{"type":"invalid"}`),
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 2 {
|
||||
t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[1].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[0].Function.Description != "User-defined web search" {
|
||||
t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Web Search Tests
|
||||
|
||||
func TestConvertTool_WebSearch(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
MaxUses: 5,
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !isServerTool {
|
||||
t.Error("expected isServerTool to be true for web_search tool")
|
||||
}
|
||||
|
||||
if result.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", result.Type)
|
||||
}
|
||||
|
||||
if result.Function.Name != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got %q", result.Function.Name)
|
||||
}
|
||||
|
||||
if result.Function.Description == "" {
|
||||
t.Error("expected non-empty description for web_search tool")
|
||||
}
|
||||
|
||||
// Check that query parameter is defined
|
||||
if result.Function.Parameters.Properties == nil {
|
||||
t.Fatal("expected properties to be defined")
|
||||
}
|
||||
|
||||
queryProp, ok := result.Function.Parameters.Properties.Get("query")
|
||||
if !ok {
|
||||
t.Error("expected 'query' property to be defined")
|
||||
}
|
||||
|
||||
if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
|
||||
t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertTool_RegularTool(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if isServerTool {
|
||||
t.Error("expected isServerTool to be false for regular tool")
|
||||
}
|
||||
|
||||
if result.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_123",
|
||||
"name": "web_search",
|
||||
"input": map[string]any{"query": "test query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if len(messages[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
|
||||
}
|
||||
|
||||
tc := messages[0].ToolCalls[0]
|
||||
if tc.ID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
|
||||
}
|
||||
|
||||
if tc.Function.Name != "web_search" {
|
||||
t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_123",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "web_search_result",
|
||||
"title": "Test Result",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have a tool result message
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if messages[0].Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", messages[0].Role)
|
||||
}
|
||||
|
||||
if messages[0].ToolCallID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
|
||||
}
|
||||
|
||||
if messages[0].Content == "" {
|
||||
t.Error("expected non-empty content from web search results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_empty",
|
||||
"content": []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_empty" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if messages[0].Content != "" {
|
||||
t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_error",
|
||||
"content": map[string]any{
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_error" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
|
||||
t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOllamaToAnthropicResults(t *testing.T) {
|
||||
ollamaResp := &OllamaWebSearchResponse{
|
||||
Results: []OllamaWebSearchResult{
|
||||
{
|
||||
Title: "Test Title",
|
||||
URL: "https://example.com",
|
||||
Content: "Test content",
|
||||
},
|
||||
{
|
||||
Title: "Another Result",
|
||||
URL: "https://example.org",
|
||||
Content: "More content",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ConvertOllamaToAnthropicResults(ollamaResp)
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].Type != "web_search_result" {
|
||||
t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
|
||||
}
|
||||
|
||||
if results[0].Title != "Test Title" {
|
||||
t.Errorf("expected title 'Test Title', got %q", results[0].Title)
|
||||
}
|
||||
|
||||
if results[0].URL != "https://example.com" {
|
||||
t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTypes(t *testing.T) {
|
||||
// Test that WebSearchResult serializes correctly
|
||||
result := WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: "https://example.com",
|
||||
Title: "Test",
|
||||
EncryptedContent: "abc123",
|
||||
PageAge: "2025-01-01",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled WebSearchResult
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != result.Type {
|
||||
t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
|
||||
}
|
||||
|
||||
// Test WebSearchToolResultError
|
||||
errResult := WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(errResult)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaledErr WebSearchToolResultError
|
||||
if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
|
||||
t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitation(t *testing.T) {
|
||||
citation := Citation{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com",
|
||||
Title: "Example",
|
||||
EncryptedIndex: "enc123",
|
||||
CitedText: "Some cited text...",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(citation)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal Citation: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled Citation
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal Citation: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != "web_search_result_location" {
|
||||
t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
|
||||
}
|
||||
|
||||
if unmarshaled.CitedText != "Some cited text..." {
|
||||
t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
|
||||
}
|
||||
}
|
||||
|
||||
352
anthropic/trace.go
Normal file
352
anthropic/trace.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Trace truncation limits.
|
||||
const (
|
||||
TraceMaxStringRunes = 240
|
||||
TraceMaxSliceItems = 8
|
||||
TraceMaxMapEntries = 16
|
||||
TraceMaxDepth = 4
|
||||
)
|
||||
|
||||
// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
|
||||
// omitted characters when truncated.
|
||||
func TraceTruncateString(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
if len(runes) <= TraceMaxStringRunes {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
|
||||
}
|
||||
|
||||
// TraceJSON round-trips v through JSON and returns a compacted representation.
|
||||
func TraceJSON(v any) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
|
||||
}
|
||||
var out any
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
return TraceTruncateString(string(data))
|
||||
}
|
||||
return TraceCompactValue(out, 0)
|
||||
}
|
||||
|
||||
// TraceCompactValue recursively truncates strings, slices, and maps for trace
|
||||
// output. depth tracks recursion to enforce TraceMaxDepth.
|
||||
func TraceCompactValue(v any, depth int) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if depth >= TraceMaxDepth {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
return fmt.Sprintf("<array len=%d>", len(t))
|
||||
case map[string]any:
|
||||
return fmt.Sprintf("<object keys=%d>", len(t))
|
||||
default:
|
||||
return fmt.Sprintf("<%T>", v)
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
limit := min(len(t), TraceMaxSliceItems)
|
||||
out := make([]any, 0, limit+1)
|
||||
for i := range limit {
|
||||
out = append(out, TraceCompactValue(t[i], depth+1))
|
||||
}
|
||||
if len(t) > limit {
|
||||
out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
keys := make([]string, 0, len(t))
|
||||
for k := range t {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
limit := min(len(keys), TraceMaxMapEntries)
|
||||
out := make(map[string]any, limit+1)
|
||||
for i := range limit {
|
||||
out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
|
||||
}
|
||||
if len(keys) > limit {
|
||||
out["__truncated_keys"] = len(keys) - limit
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic request/response tracing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
|
||||
func TraceMessagesRequest(r MessagesRequest) map[string]any {
|
||||
return map[string]any{
|
||||
"model": r.Model,
|
||||
"max_tokens": r.MaxTokens,
|
||||
"messages": traceMessageParams(r.Messages),
|
||||
"system": traceAnthropicContent(r.System),
|
||||
"stream": r.Stream,
|
||||
"tools": traceTools(r.Tools),
|
||||
"tool_choice": TraceJSON(r.ToolChoice),
|
||||
"thinking": TraceJSON(r.Thinking),
|
||||
"stop_sequences": r.StopSequences,
|
||||
"temperature": ptrVal(r.Temperature),
|
||||
"top_p": ptrVal(r.TopP),
|
||||
"top_k": ptrVal(r.TopK),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
|
||||
func TraceMessagesResponse(r MessagesResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"id": r.ID,
|
||||
"model": r.Model,
|
||||
"content": TraceJSON(r.Content),
|
||||
"stop_reason": r.StopReason,
|
||||
"usage": r.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
func traceMessageParams(msgs []MessageParam) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, map[string]any{
|
||||
"role": m.Role,
|
||||
"content": traceAnthropicContent(m.Content),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceAnthropicContent(content any) any {
|
||||
switch c := content.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
return TraceTruncateString(c)
|
||||
case []any:
|
||||
blocks := make([]any, 0, len(c))
|
||||
for _, block := range c {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
blocks = append(blocks, TraceCompactValue(block, 0))
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, traceAnthropicBlock(blockMap))
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return TraceJSON(c)
|
||||
}
|
||||
}
|
||||
|
||||
func traceAnthropicBlock(block map[string]any) map[string]any {
|
||||
blockType, _ := block["type"].(string)
|
||||
out := map[string]any{"type": blockType}
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
out["text"] = TraceTruncateString(text)
|
||||
} else {
|
||||
out["text"] = TraceCompactValue(block["text"], 0)
|
||||
}
|
||||
case "thinking":
|
||||
if thinking, ok := block["thinking"].(string); ok {
|
||||
out["thinking"] = TraceTruncateString(thinking)
|
||||
} else {
|
||||
out["thinking"] = TraceCompactValue(block["thinking"], 0)
|
||||
}
|
||||
case "tool_use", "server_tool_use":
|
||||
out["id"] = block["id"]
|
||||
out["name"] = block["name"]
|
||||
out["input"] = TraceCompactValue(block["input"], 0)
|
||||
case "tool_result", "web_search_tool_result":
|
||||
out["tool_use_id"] = block["tool_use_id"]
|
||||
out["content"] = TraceCompactValue(block["content"], 0)
|
||||
case "image":
|
||||
if source, ok := block["source"].(map[string]any); ok {
|
||||
out["source"] = map[string]any{
|
||||
"type": source["type"],
|
||||
"media_type": source["media_type"],
|
||||
"url": source["url"],
|
||||
"data_len": len(fmt.Sprint(source["data"])),
|
||||
}
|
||||
}
|
||||
default:
|
||||
out["block"] = TraceCompactValue(block, 0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceTools(tools []Tool) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceTool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceTool returns a compact trace representation of an Anthropic Tool.
|
||||
func TraceTool(t Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": TraceTruncateString(t.Description),
|
||||
"input_schema": TraceJSON(t.InputSchema),
|
||||
"max_uses": t.MaxUses,
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockTypes returns the type strings from content (when it's []any blocks).
|
||||
func ContentBlockTypes(content any) []string {
|
||||
blocks, ok := content.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
types := make([]string, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
types = append(types, fmt.Sprintf("%T", block))
|
||||
continue
|
||||
}
|
||||
t, _ := blockMap["type"].(string)
|
||||
types = append(types, t)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
func ptrVal[T any](v *T) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ollama api.* tracing (shared between anthropic and middleware packages)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
|
||||
func TraceChatRequest(req *api.ChatRequest) map[string]any {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
stream := false
|
||||
if req.Stream != nil {
|
||||
stream = *req.Stream
|
||||
}
|
||||
return map[string]any{
|
||||
"model": req.Model,
|
||||
"messages": TraceAPIMessages(req.Messages),
|
||||
"tools": TraceAPITools(req.Tools),
|
||||
"stream": stream,
|
||||
"options": req.Options,
|
||||
"think": TraceJSON(req.Think),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
|
||||
func TraceChatResponse(resp api.ChatResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"model": resp.Model,
|
||||
"done": resp.Done,
|
||||
"done_reason": resp.DoneReason,
|
||||
"message": TraceAPIMessage(resp.Message),
|
||||
"metrics": TraceJSON(resp.Metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceAPIMessages returns compact trace representations for a slice of api.Message.
|
||||
func TraceAPIMessages(msgs []api.Message) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, TraceAPIMessage(m))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPIMessage returns a compact trace representation of a single api.Message.
|
||||
func TraceAPIMessage(m api.Message) map[string]any {
|
||||
return map[string]any{
|
||||
"role": m.Role,
|
||||
"content": TraceTruncateString(m.Content),
|
||||
"thinking": TraceTruncateString(m.Thinking),
|
||||
"images": traceImageSizes(m.Images),
|
||||
"tool_calls": traceToolCalls(m.ToolCalls),
|
||||
"tool_name": m.ToolName,
|
||||
"tool_call_id": m.ToolCallID,
|
||||
}
|
||||
}
|
||||
|
||||
func traceImageSizes(images []api.ImageData) []int {
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
sizes := make([]int, 0, len(images))
|
||||
for _, img := range images {
|
||||
sizes = append(sizes, len(img))
|
||||
}
|
||||
return sizes
|
||||
}
|
||||
|
||||
// TraceAPITools returns compact trace representations for a slice of api.Tool.
|
||||
func TraceAPITools(tools api.Tools) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceAPITool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPITool returns a compact trace representation of a single api.Tool.
|
||||
func TraceAPITool(t api.Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Function.Name,
|
||||
"description": TraceTruncateString(t.Function.Description),
|
||||
"parameters": TraceJSON(t.Function.Parameters),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceToolCall returns a compact trace representation of an api.ToolCall.
|
||||
func TraceToolCall(tc api.ToolCall) map[string]any {
|
||||
return map[string]any{
|
||||
"id": tc.ID,
|
||||
"name": tc.Function.Name,
|
||||
"args": TraceJSON(tc.Function.Arguments),
|
||||
}
|
||||
}
|
||||
|
||||
func traceToolCalls(tcs []api.ToolCall) []map[string]any {
|
||||
if len(tcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(tcs))
|
||||
for _, tc := range tcs {
|
||||
out = append(out, TraceToolCall(tc))
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -449,6 +449,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
// CloudStatusExperimental returns whether cloud features are disabled on the server.
|
||||
func (c *Client) CloudStatusExperimental(ctx context.Context) (*StatusResponse, error) {
|
||||
var status StatusResponse
|
||||
if err := c.do(ctx, http.MethodGet, "/api/status", nil, &status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// Signout will signout a client for a local ollama server.
|
||||
func (c *Client) Signout(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||
|
||||
10
api/types.go
10
api/types.go
@@ -834,6 +834,16 @@ type TokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type CloudStatus struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
// StatusResponse is the response from [Client.CloudStatusExperimental].
|
||||
type StatusResponse struct {
|
||||
Cloud CloudStatus `json:"cloud"`
|
||||
}
|
||||
|
||||
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||
type GenerateResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
|
||||
@@ -205,6 +205,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cloudDisabled, err := s.store.CloudDisabled()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := commandContext(ctx, s.bin, "serve")
|
||||
cmd.Stdout, cmd.Stderr = s.log, s.log
|
||||
|
||||
@@ -230,6 +235,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
||||
if settings.ContextLength > 0 {
|
||||
env["OLLAMA_CONTEXT_LENGTH"] = strconv.Itoa(settings.ContextLength)
|
||||
}
|
||||
if cloudDisabled {
|
||||
env["OLLAMA_NO_CLOUD"] = "1"
|
||||
} else {
|
||||
env["OLLAMA_NO_CLOUD"] = "0"
|
||||
}
|
||||
cmd.Env = []string{}
|
||||
for k, v := range env {
|
||||
cmd.Env = append(cmd.Env, k+"="+v)
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestServerCmd(t *testing.T) {
|
||||
for _, want := range tt.want {
|
||||
found := false
|
||||
for _, env := range cmd.Env {
|
||||
if strings.Contains(env, want) {
|
||||
if strings.HasPrefix(env, want) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func TestServerCmd(t *testing.T) {
|
||||
|
||||
for _, dont := range tt.dont {
|
||||
for _, env := range cmd.Env {
|
||||
if strings.Contains(env, dont) {
|
||||
if strings.HasPrefix(env, dont) {
|
||||
t.Errorf("unexpected environment variable: %s", env)
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,75 @@ func TestServerCmd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerCmdCloudSettingEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
configContent string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "default cloud enabled",
|
||||
want: "OLLAMA_NO_CLOUD=0",
|
||||
},
|
||||
{
|
||||
name: "env disables cloud",
|
||||
envValue: "1",
|
||||
want: "OLLAMA_NO_CLOUD=1",
|
||||
},
|
||||
{
|
||||
name: "config disables cloud",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
want: "OLLAMA_NO_CLOUD=1",
|
||||
},
|
||||
{
|
||||
name: "invalid env disables cloud",
|
||||
envValue: "invalid",
|
||||
want: "OLLAMA_NO_CLOUD=1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
t.Setenv("USERPROFILE", tmpHome)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||
|
||||
if tt.configContent != "" {
|
||||
configDir := filepath.Join(tmpHome, ".ollama")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir config dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "server.json")
|
||||
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
st := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
|
||||
defer st.Close()
|
||||
|
||||
s := &Server{store: st}
|
||||
cmd, err := s.cmd(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("s.cmd() error = %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, env := range cmd.Env {
|
||||
if env == tt.want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected environment variable %q in command env", tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInferenceComputer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
128
app/store/cloud_config.go
Normal file
128
app/store/cloud_config.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const serverConfigFilename = "server.json"
|
||||
|
||||
type serverConfig struct {
|
||||
DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
|
||||
}
|
||||
|
||||
// CloudDisabled returns whether cloud features should be disabled.
|
||||
// The source of truth is: OLLAMA_NO_CLOUD OR ~/.ollama/server.json:disable_ollama_cloud.
|
||||
func (s *Store) CloudDisabled() (bool, error) {
|
||||
disabled, _, err := s.CloudStatus()
|
||||
return disabled, err
|
||||
}
|
||||
|
||||
// CloudStatus returns whether cloud is disabled and the source of that decision.
|
||||
// Source is one of: "none", "env", "config", "both".
|
||||
func (s *Store) CloudStatus() (bool, string, error) {
|
||||
if err := s.ensureDB(); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
configDisabled, err := readServerConfigCloudDisabled()
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
envDisabled := envconfig.NoCloudEnv()
|
||||
return envDisabled || configDisabled, cloudStatusSource(envDisabled, configDisabled), nil
|
||||
}
|
||||
|
||||
// SetCloudEnabled writes the cloud setting to ~/.ollama/server.json.
|
||||
func (s *Store) SetCloudEnabled(enabled bool) error {
|
||||
if err := s.ensureDB(); err != nil {
|
||||
return err
|
||||
}
|
||||
return setCloudEnabled(enabled)
|
||||
}
|
||||
|
||||
func setCloudEnabled(enabled bool) error {
|
||||
configPath, err := serverConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return fmt.Errorf("create server config directory: %w", err)
|
||||
}
|
||||
|
||||
configMap := map[string]any{}
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
// If the existing file is invalid JSON, overwrite with a fresh object.
|
||||
configMap = map[string]any{}
|
||||
}
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("read server config: %w", err)
|
||||
}
|
||||
|
||||
configMap["disable_ollama_cloud"] = !enabled
|
||||
|
||||
data, err := json.MarshalIndent(configMap, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal server config: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
return fmt.Errorf("write server config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readServerConfigCloudDisabled() (bool, error) {
|
||||
configPath, err := serverConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("read server config: %w", err)
|
||||
}
|
||||
|
||||
var cfg serverConfig
|
||||
// Invalid or unexpected JSON should not block startup; treat as default.
|
||||
if json.Unmarshal(data, &cfg) == nil {
|
||||
return cfg.DisableOllamaCloud, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func serverConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".ollama", serverConfigFilename), nil
|
||||
}
|
||||
|
||||
func cloudStatusSource(envDisabled bool, configDisabled bool) string {
|
||||
switch {
|
||||
case envDisabled && configDisabled:
|
||||
return "both"
|
||||
case envDisabled:
|
||||
return "env"
|
||||
case configDisabled:
|
||||
return "config"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
130
app/store/cloud_config_test.go
Normal file
130
app/store/cloud_config_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCloudDisabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
configContent string
|
||||
wantDisabled bool
|
||||
wantSource string
|
||||
}{
|
||||
{
|
||||
name: "default enabled",
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
{
|
||||
name: "env disables cloud",
|
||||
envValue: "1",
|
||||
wantDisabled: true,
|
||||
wantSource: "env",
|
||||
},
|
||||
{
|
||||
name: "config disables cloud",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
wantDisabled: true,
|
||||
wantSource: "config",
|
||||
},
|
||||
{
|
||||
name: "env and config",
|
||||
envValue: "1",
|
||||
configContent: `{"disable_ollama_cloud": false}`,
|
||||
wantDisabled: true,
|
||||
wantSource: "env",
|
||||
},
|
||||
{
|
||||
name: "invalid config is ignored",
|
||||
configContent: `{bad`,
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
setTestHome(t, tmpHome)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||
|
||||
if tt.configContent != "" {
|
||||
configDir := filepath.Join(tmpHome, ".ollama")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir config dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, serverConfigFilename)
|
||||
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
|
||||
defer s.Close()
|
||||
|
||||
disabled, err := s.CloudDisabled()
|
||||
if err != nil {
|
||||
t.Fatalf("CloudDisabled() error = %v", err)
|
||||
}
|
||||
if disabled != tt.wantDisabled {
|
||||
t.Fatalf("CloudDisabled() = %v, want %v", disabled, tt.wantDisabled)
|
||||
}
|
||||
|
||||
statusDisabled, source, err := s.CloudStatus()
|
||||
if err != nil {
|
||||
t.Fatalf("CloudStatus() error = %v", err)
|
||||
}
|
||||
if statusDisabled != tt.wantDisabled {
|
||||
t.Fatalf("CloudStatus() disabled = %v, want %v", statusDisabled, tt.wantDisabled)
|
||||
}
|
||||
if source != tt.wantSource {
|
||||
t.Fatalf("CloudStatus() source = %v, want %v", source, tt.wantSource)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCloudEnabled(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
setTestHome(t, tmpHome)
|
||||
|
||||
configDir := filepath.Join(tmpHome, ".ollama")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir config dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, serverConfigFilename)
|
||||
if err := os.WriteFile(configPath, []byte(`{"another_key":"value","disable_ollama_cloud":true}`), 0o644); err != nil {
|
||||
t.Fatalf("seed config: %v", err)
|
||||
}
|
||||
|
||||
s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
|
||||
defer s.Close()
|
||||
|
||||
if err := s.SetCloudEnabled(true); err != nil {
|
||||
t.Fatalf("SetCloudEnabled(true) error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read config: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
if got["disable_ollama_cloud"] != false {
|
||||
t.Fatalf("disable_ollama_cloud = %v, want false", got["disable_ollama_cloud"])
|
||||
}
|
||||
if got["another_key"] != "value" {
|
||||
t.Fatalf("another_key = %v, want value", got["another_key"])
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 12
|
||||
const currentSchemaVersion = 13
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -84,6 +84,7 @@ func (db *database) init() error {
|
||||
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
|
||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
think_level TEXT NOT NULL DEFAULT '',
|
||||
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||
schema_version INTEGER NOT NULL DEFAULT %d
|
||||
);
|
||||
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v11 to v12: %w", err)
|
||||
}
|
||||
version = 12
|
||||
case 12:
|
||||
// add cloud_setting_migrated column to settings table
|
||||
if err := db.migrateV12ToV13(); err != nil {
|
||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||
}
|
||||
version = 13
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV12ToV13 adds cloud_setting_migrated to settings.
|
||||
func (db *database) migrateV12ToV13() error {
|
||||
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0`)
|
||||
if err != nil && !duplicateColumnError(err) {
|
||||
return fmt.Errorf("add cloud_setting_migrated column: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
@@ -1108,9 +1130,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||
var s Settings
|
||||
|
||||
err := db.conn.QueryRow(`
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||
FROM settings
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||
if err != nil {
|
||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||
}
|
||||
@@ -1121,14 +1143,40 @@ func (db *database) getSettings() (Settings, error) {
|
||||
func (db *database) setSettings(s Settings) error {
|
||||
_, err := db.conn.Exec(`
|
||||
UPDATE settings
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set settings: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *database) isCloudSettingMigrated() (bool, error) {
|
||||
var migrated bool
|
||||
err := db.conn.QueryRow("SELECT cloud_setting_migrated FROM settings").Scan(&migrated)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get cloud setting migration status: %w", err)
|
||||
}
|
||||
return migrated, nil
|
||||
}
|
||||
|
||||
func (db *database) setCloudSettingMigrated(migrated bool) error {
|
||||
_, err := db.conn.Exec("UPDATE settings SET cloud_setting_migrated = ?", migrated)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set cloud setting migration status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *database) getAirplaneMode() (bool, error) {
|
||||
var airplaneMode bool
|
||||
err := db.conn.QueryRow("SELECT airplane_mode FROM settings").Scan(&airplaneMode)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get airplane_mode: %w", err)
|
||||
}
|
||||
return airplaneMode, nil
|
||||
}
|
||||
|
||||
func (db *database) getWindowSize() (int, int, error) {
|
||||
var width, height int
|
||||
err := db.conn.QueryRow("SELECT window_width, window_height FROM settings").Scan(&width, &height)
|
||||
|
||||
@@ -127,6 +127,65 @@ func TestNoConfigToMigrate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudMigrationFromAirplaneMode(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
setTestHome(t, tmpHome)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
|
||||
dbPath := filepath.Join(tmpHome, "db.sqlite")
|
||||
db, err := newDatabase(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.conn.Exec("UPDATE settings SET airplane_mode = 1, cloud_setting_migrated = 0"); err != nil {
|
||||
db.Close()
|
||||
t.Fatalf("failed to seed airplane migration state: %v", err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
s := Store{DBPath: dbPath}
|
||||
defer s.Close()
|
||||
|
||||
// Trigger DB initialization + one-time cloud migration.
|
||||
if _, err := s.ID(); err != nil {
|
||||
t.Fatalf("failed to initialize store: %v", err)
|
||||
}
|
||||
|
||||
disabled, err := s.CloudDisabled()
|
||||
if err != nil {
|
||||
t.Fatalf("CloudDisabled() error: %v", err)
|
||||
}
|
||||
if !disabled {
|
||||
t.Fatal("expected cloud to be disabled after migrating airplane_mode=true")
|
||||
}
|
||||
|
||||
configPath := filepath.Join(tmpHome, ".ollama", serverConfigFilename)
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read migrated server config: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("failed to parse migrated server config: %v", err)
|
||||
}
|
||||
if cfg["disable_ollama_cloud"] != true {
|
||||
t.Fatalf("disable_ollama_cloud = %v, want true", cfg["disable_ollama_cloud"])
|
||||
}
|
||||
|
||||
var airplaneMode, migrated bool
|
||||
if err := s.db.conn.QueryRow("SELECT airplane_mode, cloud_setting_migrated FROM settings").Scan(&airplaneMode, &migrated); err != nil {
|
||||
t.Fatalf("failed to read migration flags from DB: %v", err)
|
||||
}
|
||||
if !airplaneMode {
|
||||
t.Fatal("expected legacy airplane_mode value to remain unchanged")
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected cloud_setting_migrated to be true")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
v1Schema = `
|
||||
CREATE TABLE IF NOT EXISTS settings (
|
||||
|
||||
@@ -149,9 +149,6 @@ type Settings struct {
|
||||
// ContextLength specifies the context length for the ollama server (using OLLAMA_CONTEXT_LENGTH)
|
||||
ContextLength int
|
||||
|
||||
// AirplaneMode when true, turns off Ollama Turbo features and only uses local models
|
||||
AirplaneMode bool
|
||||
|
||||
// TurboEnabled indicates if Ollama Turbo features are enabled
|
||||
TurboEnabled bool
|
||||
|
||||
@@ -259,6 +256,40 @@ func (s *Store) ensureDB() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Run one-time migration from legacy airplane_mode behavior.
|
||||
if err := s.migrateCloudSetting(database); err != nil {
|
||||
return fmt.Errorf("migrate cloud setting: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateCloudSetting migrates legacy airplane_mode into server.json exactly once.
|
||||
// After this, cloud state is sourced from server.json OR OLLAMA_NO_CLOUD.
|
||||
func (s *Store) migrateCloudSetting(database *database) error {
|
||||
migrated, err := database.isCloudSettingMigrated()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if migrated {
|
||||
return nil
|
||||
}
|
||||
|
||||
airplaneMode, err := database.getAirplaneMode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if airplaneMode {
|
||||
if err := setCloudEnabled(false); err != nil {
|
||||
return fmt.Errorf("migrate airplane_mode to cloud disabled: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := database.setCloudSettingMigrated(true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
11
app/store/test_home_test.go
Normal file
11
app/store/test_home_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package store
|
||||
|
||||
import "testing"
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
}
|
||||
35
app/tools/cloud_policy.go
Normal file
35
app/tools/cloud_policy.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
// ensureCloudEnabledForTool checks cloud policy from the connected Ollama server.
|
||||
// If policy cannot be determined, this fails closed and blocks the operation.
|
||||
func ensureCloudEnabledForTool(ctx context.Context, operation string) error {
|
||||
// Reuse shared message formatting; policy evaluation is still done via
|
||||
// the connected server's /api/status endpoint below.
|
||||
disabledMessage := internalcloud.DisabledError(operation)
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
|
||||
}
|
||||
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
|
||||
}
|
||||
|
||||
if status.Cloud.Disabled {
|
||||
return errors.New(disabledMessage)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
73
app/tools/cloud_policy_test.go
Normal file
73
app/tools/cloud_policy_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnsureCloudEnabledForTool(t *testing.T) {
|
||||
const op = "web search is unavailable"
|
||||
const disabledPrefix = "ollama cloud is disabled: web search is unavailable"
|
||||
|
||||
t.Run("enabled allows tool execution", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/status" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cloud":{"disabled":false,"source":"none"}}`))
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
t.Setenv("OLLAMA_HOST", ts.URL)
|
||||
|
||||
if err := ensureCloudEnabledForTool(context.Background(), op); err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled blocks tool execution", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/status" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cloud":{"disabled":true,"source":"config"}}`))
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
t.Setenv("OLLAMA_HOST", ts.URL)
|
||||
|
||||
err := ensureCloudEnabledForTool(context.Background(), op)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if got := err.Error(); got != disabledPrefix {
|
||||
t.Fatalf("unexpected error: %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status unavailable fails closed", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
t.Setenv("OLLAMA_HOST", ts.URL)
|
||||
|
||||
err := ensureCloudEnabledForTool(context.Background(), op)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if got := err.Error(); !strings.Contains(got, disabledPrefix) {
|
||||
t.Fatalf("expected disabled prefix, got %q", got)
|
||||
}
|
||||
if got := err.Error(); !strings.Contains(got, "unable to verify server cloud policy") {
|
||||
t.Fatalf("expected verification failure detail, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -77,6 +77,10 @@ func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, strin
|
||||
}
|
||||
|
||||
func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
|
||||
if err := ensureCloudEnabledForTool(ctx, "web fetch is unavailable"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqBody := FetchRequest{URL: targetURL}
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -93,6 +93,10 @@ func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, stri
|
||||
}
|
||||
|
||||
func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
|
||||
if err := ensureCloudEnabledForTool(ctx, "web search is unavailable"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqBody := SearchRequest{Query: query, MaxResults: maxResults}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
|
||||
@@ -406,7 +406,6 @@ export class Settings {
|
||||
Tools: boolean;
|
||||
WorkingDir: string;
|
||||
ContextLength: number;
|
||||
AirplaneMode: boolean;
|
||||
TurboEnabled: boolean;
|
||||
WebSearchEnabled: boolean;
|
||||
ThinkEnabled: boolean;
|
||||
@@ -424,7 +423,6 @@ export class Settings {
|
||||
this.Tools = source["Tools"];
|
||||
this.WorkingDir = source["WorkingDir"];
|
||||
this.ContextLength = source["ContextLength"];
|
||||
this.AirplaneMode = source["AirplaneMode"];
|
||||
this.TurboEnabled = source["TurboEnabled"];
|
||||
this.WebSearchEnabled = source["WebSearchEnabled"];
|
||||
this.ThinkEnabled = source["ThinkEnabled"];
|
||||
|
||||
@@ -27,6 +27,12 @@ declare module "@/gotypes" {
|
||||
Model.prototype.isCloud = function (): boolean {
|
||||
return this.model.endsWith("cloud");
|
||||
};
|
||||
|
||||
export type CloudStatusSource = "env" | "config" | "both" | "none";
|
||||
export interface CloudStatusResponse {
|
||||
disabled: boolean;
|
||||
source: CloudStatusSource;
|
||||
}
|
||||
// Helper function to convert Uint8Array to base64
|
||||
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
||||
@@ -285,6 +291,28 @@ export async function updateSettings(settings: Settings): Promise<{
|
||||
};
|
||||
}
|
||||
|
||||
export async function updateCloudSetting(
|
||||
enabled: boolean,
|
||||
): Promise<CloudStatusResponse> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/cloud`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ enabled }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(error || "Failed to update cloud setting");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return {
|
||||
disabled: Boolean(data.disabled),
|
||||
source: (data.source as CloudStatusSource) || "none",
|
||||
};
|
||||
}
|
||||
|
||||
export async function renameChat(chatId: string, title: string): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/rename`, {
|
||||
method: "PUT",
|
||||
@@ -414,3 +442,16 @@ export async function fetchHealth(): Promise<boolean> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCloudStatus(): Promise<CloudStatusResponse | null> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/cloud`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch cloud status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return {
|
||||
disabled: Boolean(data.disabled),
|
||||
source: (data.source as CloudStatusSource) || "none",
|
||||
};
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import { useUser } from "@/hooks/useUser";
|
||||
import { DisplayLogin } from "@/components/DisplayLogin";
|
||||
import { ErrorEvent, Message } from "@/gotypes";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||
import { ThinkButton } from "./ThinkButton";
|
||||
import { ErrorMessage } from "./ErrorMessage";
|
||||
import { processFiles } from "@/utils/fileValidation";
|
||||
@@ -141,12 +142,12 @@ function ChatForm({
|
||||
const {
|
||||
settings: {
|
||||
webSearchEnabled,
|
||||
airplaneMode,
|
||||
thinkEnabled,
|
||||
thinkLevel: settingsThinkLevel,
|
||||
},
|
||||
setSettings,
|
||||
} = useSettings();
|
||||
const { cloudDisabled } = useCloudStatus();
|
||||
|
||||
// current supported models for web search
|
||||
const modelLower = selectedModel?.model.toLowerCase() || "";
|
||||
@@ -180,6 +181,12 @@ function ChatForm({
|
||||
setSettings,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (cloudDisabled && webSearchEnabled) {
|
||||
setSettings({ WebSearchEnabled: false });
|
||||
}
|
||||
}, [cloudDisabled, webSearchEnabled, setSettings]);
|
||||
|
||||
const removeFile = (index: number) => {
|
||||
setMessage((prev) => ({
|
||||
...prev,
|
||||
@@ -234,19 +241,19 @@ function ChatForm({
|
||||
|
||||
// Determine if login banner should be shown
|
||||
const shouldShowLoginBanner =
|
||||
!cloudDisabled &&
|
||||
!isLoadingUser &&
|
||||
!isAuthenticated &&
|
||||
((webSearchEnabled && supportsWebSearch) ||
|
||||
(selectedModel?.isCloud() && !airplaneMode));
|
||||
((webSearchEnabled && supportsWebSearch) || selectedModel?.isCloud());
|
||||
|
||||
// Determine which feature to highlight in the banner
|
||||
const getActiveFeatureForBanner = () => {
|
||||
if (cloudDisabled) return null;
|
||||
if (!isAuthenticated) {
|
||||
if (loginPromptFeature) return loginPromptFeature;
|
||||
if (webSearchEnabled && selectedModel?.isCloud() && !airplaneMode)
|
||||
return "webSearch";
|
||||
if (webSearchEnabled && selectedModel?.isCloud()) return "webSearch";
|
||||
if (webSearchEnabled) return "webSearch";
|
||||
if (selectedModel?.isCloud() && !airplaneMode) return "turbo";
|
||||
if (selectedModel?.isCloud()) return "turbo";
|
||||
}
|
||||
return null;
|
||||
};
|
||||
@@ -269,11 +276,12 @@ function ChatForm({
|
||||
useEffect(() => {
|
||||
if (
|
||||
isAuthenticated ||
|
||||
(!webSearchEnabled && !!selectedModel?.isCloud() && !airplaneMode)
|
||||
cloudDisabled ||
|
||||
(!webSearchEnabled && !!selectedModel?.isCloud())
|
||||
) {
|
||||
setLoginPromptFeature(null);
|
||||
}
|
||||
}, [isAuthenticated, webSearchEnabled, selectedModel, airplaneMode]);
|
||||
}, [isAuthenticated, webSearchEnabled, selectedModel, cloudDisabled]);
|
||||
|
||||
// When entering edit mode, populate the composition with existing data
|
||||
useEffect(() => {
|
||||
@@ -465,6 +473,10 @@ function ChatForm({
|
||||
const handleSubmit = async () => {
|
||||
if (!message.content.trim() || isStreaming || isDownloading) return;
|
||||
|
||||
if (cloudDisabled && selectedModel?.isCloud()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if cloud mode is enabled but user is not authenticated
|
||||
if (shouldShowLoginBanner) {
|
||||
return;
|
||||
@@ -478,7 +490,8 @@ function ChatForm({
|
||||
}),
|
||||
);
|
||||
|
||||
const useWebSearch = supportsWebSearch && webSearchEnabled && !airplaneMode;
|
||||
const useWebSearch =
|
||||
supportsWebSearch && webSearchEnabled && !cloudDisabled;
|
||||
const useThink = modelSupportsThinkingLevels
|
||||
? thinkLevel
|
||||
: supportsThinkToggling
|
||||
@@ -899,7 +912,7 @@ function ChatForm({
|
||||
)}
|
||||
<WebSearchButton
|
||||
ref={webSearchButtonRef}
|
||||
isVisible={supportsWebSearch && airplaneMode === false}
|
||||
isVisible={supportsWebSearch && cloudDisabled === false}
|
||||
isActive={webSearchEnabled}
|
||||
onToggle={() => {
|
||||
if (!webSearchEnabled && !isAuthenticated) {
|
||||
@@ -940,6 +953,7 @@ function ChatForm({
|
||||
!isDownloading &&
|
||||
(!message.content.trim() ||
|
||||
shouldShowLoginBanner ||
|
||||
(cloudDisabled && selectedModel?.isCloud()) ||
|
||||
message.fileErrors.length > 0)
|
||||
}
|
||||
className={`flex items-center justify-center h-9 w-9 rounded-full disabled:cursor-default cursor-pointer bg-black text-white dark:bg-white dark:text-black disabled:opacity-10 focus:outline-none focus:ring-2 focus:ring-blue-500`}
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
} from "react";
|
||||
import { Model } from "@/gotypes";
|
||||
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { getModelUpstreamInfo } from "@/api";
|
||||
import { ArrowDownTrayIcon } from "@heroicons/react/24/outline";
|
||||
@@ -34,7 +34,7 @@ export const ModelPicker = forwardRef<
|
||||
chatId,
|
||||
searchQuery,
|
||||
);
|
||||
const { settings } = useSettings();
|
||||
const { cloudDisabled } = useCloudStatus();
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
const searchInputRef = useRef<HTMLInputElement>(null);
|
||||
const queryClient = useQueryClient();
|
||||
@@ -219,7 +219,7 @@ export const ModelPicker = forwardRef<
|
||||
models={models}
|
||||
selectedModel={selectedModel}
|
||||
onModelSelect={handleModelSelect}
|
||||
airplaneMode={settings.airplaneMode}
|
||||
cloudDisabled={cloudDisabled}
|
||||
isOpen={isOpen}
|
||||
/>
|
||||
</div>
|
||||
@@ -233,13 +233,13 @@ export const ModelList = forwardRef(function ModelList(
|
||||
models,
|
||||
selectedModel,
|
||||
onModelSelect,
|
||||
airplaneMode,
|
||||
cloudDisabled,
|
||||
isOpen,
|
||||
}: {
|
||||
models: Model[];
|
||||
selectedModel: Model | null;
|
||||
onModelSelect: (model: Model) => void;
|
||||
airplaneMode: boolean;
|
||||
cloudDisabled: boolean;
|
||||
isOpen: boolean;
|
||||
},
|
||||
ref,
|
||||
@@ -348,7 +348,7 @@ export const ModelList = forwardRef(function ModelList(
|
||||
</svg>
|
||||
)}
|
||||
{model.digest === undefined &&
|
||||
(airplaneMode || !model.isCloud()) && (
|
||||
(cloudDisabled || !model.isCloud()) && (
|
||||
<ArrowDownTrayIcon
|
||||
className="h-4 w-4 text-neutral-500 dark:text-neutral-400"
|
||||
strokeWidth={1.75}
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
FolderIcon,
|
||||
BoltIcon,
|
||||
WrenchIcon,
|
||||
CloudIcon,
|
||||
XMarkIcon,
|
||||
CogIcon,
|
||||
ArrowLeftIcon,
|
||||
@@ -18,8 +19,14 @@ import {
|
||||
import { Settings as SettingsType } from "@/gotypes";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
import { useUser } from "@/hooks/useUser";
|
||||
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { getSettings, updateSettings } from "@/api";
|
||||
import {
|
||||
getSettings,
|
||||
type CloudStatusResponse,
|
||||
updateCloudSetting,
|
||||
updateSettings,
|
||||
} from "@/api";
|
||||
|
||||
function AnimatedDots() {
|
||||
return (
|
||||
@@ -53,6 +60,11 @@ export default function Settings() {
|
||||
const [connectionError, setConnectionError] = useState<string | null>(null);
|
||||
const [pollingInterval, setPollingInterval] = useState<number | null>(null);
|
||||
const navigate = useNavigate();
|
||||
const {
|
||||
cloudDisabled,
|
||||
cloudStatus,
|
||||
isLoading: cloudStatusLoading,
|
||||
} = useCloudStatus();
|
||||
|
||||
const {
|
||||
data: settingsData,
|
||||
@@ -74,6 +86,50 @@ export default function Settings() {
|
||||
},
|
||||
});
|
||||
|
||||
const updateCloudMutation = useMutation({
|
||||
mutationFn: (enabled: boolean) => updateCloudSetting(enabled),
|
||||
onMutate: async (enabled: boolean) => {
|
||||
await queryClient.cancelQueries({ queryKey: ["cloudStatus"] });
|
||||
|
||||
const previous = queryClient.getQueryData<CloudStatusResponse | null>([
|
||||
"cloudStatus",
|
||||
]);
|
||||
const envForcesDisabled =
|
||||
previous?.source === "env" || previous?.source === "both";
|
||||
|
||||
queryClient.setQueryData<CloudStatusResponse | null>(
|
||||
["cloudStatus"],
|
||||
previous
|
||||
? {
|
||||
...previous,
|
||||
disabled: !enabled || envForcesDisabled,
|
||||
}
|
||||
: {
|
||||
disabled: !enabled,
|
||||
source: "config",
|
||||
},
|
||||
);
|
||||
|
||||
return { previous };
|
||||
},
|
||||
onError: (_error, _enabled, context) => {
|
||||
if (context?.previous !== undefined) {
|
||||
queryClient.setQueryData(["cloudStatus"], context.previous);
|
||||
}
|
||||
},
|
||||
onSuccess: (status) => {
|
||||
queryClient.setQueryData<CloudStatusResponse | null>(
|
||||
["cloudStatus"],
|
||||
status,
|
||||
);
|
||||
queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||
queryClient.invalidateQueries({ queryKey: ["cloudStatus"] });
|
||||
|
||||
setShowSaved(true);
|
||||
setTimeout(() => setShowSaved(false), 1500);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
refetchUser();
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
@@ -149,12 +205,16 @@ export default function Settings() {
|
||||
Agent: false,
|
||||
Tools: false,
|
||||
ContextLength: 4096,
|
||||
AirplaneMode: false,
|
||||
});
|
||||
updateSettingsMutation.mutate(defaultSettings);
|
||||
}
|
||||
};
|
||||
|
||||
const cloudOverriddenByEnv =
|
||||
cloudStatus?.source === "env" || cloudStatus?.source === "both";
|
||||
const cloudToggleDisabled =
|
||||
cloudStatusLoading || updateCloudMutation.isPending || cloudOverriddenByEnv;
|
||||
|
||||
const handleConnectOllamaAccount = async () => {
|
||||
setConnectionError(null);
|
||||
|
||||
@@ -237,7 +297,7 @@ export default function Settings() {
|
||||
<div className="space-y-4 max-w-2xl mx-auto">
|
||||
{/* Connect Ollama Account */}
|
||||
<div className="overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||
<div className="p-4 border-b border-neutral-200 dark:border-neutral-800">
|
||||
<div className="p-4">
|
||||
<Field>
|
||||
{isLoading ? (
|
||||
// Loading skeleton, this will only happen if the app started recently
|
||||
@@ -344,6 +404,34 @@ export default function Settings() {
|
||||
{/* Local Configuration */}
|
||||
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||
<div className="space-y-4 p-4">
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<CloudIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||
<div>
|
||||
<Label>Cloud</Label>
|
||||
<Description>
|
||||
{cloudOverriddenByEnv
|
||||
? "The OLLAMA_NO_CLOUD environment variable is currently forcing cloud off."
|
||||
: "Enable cloud models and web search."}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={!cloudDisabled}
|
||||
disabled={cloudToggleDisabled}
|
||||
onChange={(checked) => {
|
||||
if (cloudOverriddenByEnv) {
|
||||
return;
|
||||
}
|
||||
updateCloudMutation.mutate(checked);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Expose Ollama */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
@@ -440,35 +528,6 @@ export default function Settings() {
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
{/* Airplane Mode */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<svg
|
||||
className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100"
|
||||
viewBox="0 0 21.5508 17.9033"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path d="M21.5508 8.94727C21.542 7.91895 20.1445 7.17188 18.4658 7.17188L14.9238 7.17188C14.4316 7.17188 14.2471 7.09277 13.957 6.75879L8.05078 0.316406C7.86621 0.105469 7.6377 0 7.37402 0L6.35449 0C6.12598 0 5.99414 0.202148 6.1084 0.448242L9.14941 7.17188L4.68457 7.68164L3.09375 4.76367C2.97949 4.54395 2.78613 4.44727 2.49609 4.44727L2.11816 4.44727C1.88965 4.44727 1.74023 4.59668 1.74023 4.8252L1.74023 13.0693C1.74023 13.2979 1.88965 13.4385 2.11816 13.4385L2.49609 13.4385C2.78613 13.4385 2.97949 13.3418 3.09375 13.1309L4.68457 10.2129L9.14941 10.7227L6.1084 17.4463C5.99414 17.6836 6.12598 17.8945 6.35449 17.8945L7.37402 17.8945C7.6377 17.8945 7.86621 17.7803 8.05078 17.5781L13.957 11.127C14.2471 10.8018 14.4316 10.7227 14.9238 10.7227L18.4658 10.7227C20.1445 10.7227 21.542 9.9668 21.5508 8.94727Z" />
|
||||
</svg>
|
||||
<div>
|
||||
<Label>Airplane mode</Label>
|
||||
<Description>
|
||||
Airplane mode keeps data local, disabling cloud models
|
||||
and web search.
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={settings.AirplaneMode}
|
||||
onChange={(checked) =>
|
||||
handleChange("AirplaneMode", checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useSelectedModel } from "./useSelectedModel";
|
||||
import { createQueryBatcher } from "./useQueryBatcher";
|
||||
import { useRefetchModels } from "./useModels";
|
||||
import { useStreamingContext } from "@/contexts/StreamingContext";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { getModelCapabilities } from "@/api";
|
||||
import { useCloudStatus } from "./useCloudStatus";
|
||||
|
||||
export const useChats = () => {
|
||||
return useQuery({
|
||||
@@ -116,11 +116,9 @@ export const useIsModelStale = (modelName: string) => {
|
||||
export const useShouldShowStaleDisplay = (model: Model | null) => {
|
||||
const isStale = useIsModelStale(model?.model || "");
|
||||
const { data: dismissedModels } = useDismissedStaleModels();
|
||||
const {
|
||||
settings: { airplaneMode },
|
||||
} = useSettings();
|
||||
const { cloudDisabled } = useCloudStatus();
|
||||
|
||||
if (model?.isCloud() && !airplaneMode) {
|
||||
if (model?.isCloud() && !cloudDisabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
20
app/ui/app/src/hooks/useCloudStatus.ts
Normal file
20
app/ui/app/src/hooks/useCloudStatus.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { getCloudStatus, type CloudStatusResponse } from "@/api";
|
||||
|
||||
export function useCloudStatus() {
|
||||
const cloudQuery = useQuery<CloudStatusResponse | null>({
|
||||
queryKey: ["cloudStatus"],
|
||||
queryFn: getCloudStatus,
|
||||
retry: false,
|
||||
staleTime: 60 * 1000,
|
||||
});
|
||||
|
||||
return {
|
||||
cloudStatus: cloudQuery.data,
|
||||
cloudDisabled: cloudQuery.data?.disabled ?? false,
|
||||
isKnown: cloudQuery.data !== null && cloudQuery.data !== undefined,
|
||||
isLoading: cloudQuery.isLoading,
|
||||
isError: cloudQuery.isError,
|
||||
error: cloudQuery.error,
|
||||
};
|
||||
}
|
||||
@@ -2,11 +2,11 @@ import { useQuery } from "@tanstack/react-query";
|
||||
import { Model } from "@/gotypes";
|
||||
import { getModels } from "@/api";
|
||||
import { mergeModels } from "@/utils/mergeModels";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { useMemo } from "react";
|
||||
import { useCloudStatus } from "./useCloudStatus";
|
||||
|
||||
export function useModels(searchQuery = "") {
|
||||
const { settings } = useSettings();
|
||||
const { cloudDisabled } = useCloudStatus();
|
||||
const localQuery = useQuery<Model[], Error>({
|
||||
queryKey: ["models", searchQuery],
|
||||
queryFn: () => getModels(searchQuery),
|
||||
@@ -20,7 +20,7 @@ export function useModels(searchQuery = "") {
|
||||
});
|
||||
|
||||
const allModels = useMemo(() => {
|
||||
const models = mergeModels(localQuery.data || [], settings.airplaneMode);
|
||||
const models = mergeModels(localQuery.data || [], cloudDisabled);
|
||||
|
||||
if (searchQuery && searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase().trim();
|
||||
@@ -40,7 +40,7 @@ export function useModels(searchQuery = "") {
|
||||
}
|
||||
|
||||
return models;
|
||||
}, [localQuery.data, searchQuery, settings.airplaneMode]);
|
||||
}, [localQuery.data, searchQuery, cloudDisabled]);
|
||||
|
||||
return {
|
||||
...localQuery,
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Model } from "@/gotypes";
|
||||
import { FEATURED_MODELS } from "@/utils/mergeModels";
|
||||
import { getTotalVRAM } from "@/utils/vram.ts";
|
||||
import { getInferenceCompute } from "@/api";
|
||||
import { useCloudStatus } from "./useCloudStatus";
|
||||
|
||||
export function recommendDefaultModel(totalVRAM: number): string {
|
||||
const vram = Math.max(0, Number(totalVRAM) || 0);
|
||||
@@ -22,6 +23,7 @@ export function recommendDefaultModel(totalVRAM: number): string {
|
||||
export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
const { settings, setSettings } = useSettings();
|
||||
const { data: models = [], isLoading } = useModels(searchQuery || "");
|
||||
const { cloudDisabled } = useCloudStatus();
|
||||
const { data: chatData, isLoading: isChatLoading } = useChat(
|
||||
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
||||
);
|
||||
@@ -46,12 +48,11 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
const restoredChatRef = useRef<string | null>(null);
|
||||
|
||||
const selectedModel: Model | null = useMemo(() => {
|
||||
// if airplane mode is on and selected model ends with cloud,
|
||||
// switch to recommended default model
|
||||
if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
|
||||
// If cloud is disabled and selected model ends with cloud, switch to a local default.
|
||||
if (cloudDisabled && settings.selectedModel?.endsWith("cloud")) {
|
||||
return (
|
||||
models.find((m) => m.model === recommendedModel) ||
|
||||
models.find((m) => m.isCloud) ||
|
||||
models.find((m) => !m.isCloud()) ||
|
||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||
models[0] ||
|
||||
null
|
||||
@@ -68,7 +69,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
"qwen3-coder:480b",
|
||||
];
|
||||
const shouldMigrate =
|
||||
!settings.airplaneMode &&
|
||||
!cloudDisabled &&
|
||||
settings.turboEnabled &&
|
||||
baseModelsToMigrate.includes(settings.selectedModel);
|
||||
|
||||
@@ -96,13 +97,18 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
})) ||
|
||||
null
|
||||
);
|
||||
}, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
|
||||
}, [
|
||||
models,
|
||||
settings.selectedModel,
|
||||
cloudDisabled,
|
||||
recommendedModel,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedModel) return;
|
||||
|
||||
if (
|
||||
settings.airplaneMode &&
|
||||
cloudDisabled &&
|
||||
settings.selectedModel?.endsWith("cloud") &&
|
||||
selectedModel.model !== settings.selectedModel
|
||||
) {
|
||||
@@ -110,13 +116,17 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
}
|
||||
|
||||
if (
|
||||
!settings.airplaneMode &&
|
||||
!cloudDisabled &&
|
||||
settings.turboEnabled &&
|
||||
selectedModel.model !== settings.selectedModel
|
||||
) {
|
||||
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
|
||||
}
|
||||
}, [selectedModel, settings.airplaneMode, settings.selectedModel]);
|
||||
}, [
|
||||
selectedModel,
|
||||
cloudDisabled,
|
||||
settings.selectedModel,
|
||||
]);
|
||||
|
||||
// Set model from chat history when chat data loads
|
||||
useEffect(() => {
|
||||
@@ -169,7 +179,9 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
|
||||
const defaultModel =
|
||||
models.find((m) => m.model === recommendedModel) ||
|
||||
models.find((m) => m.isCloud()) ||
|
||||
(cloudDisabled
|
||||
? models.find((m) => !m.isCloud())
|
||||
: models.find((m) => m.isCloud())) ||
|
||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||
models[0];
|
||||
|
||||
@@ -181,6 +193,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
inferenceComputes.length,
|
||||
models.length,
|
||||
settings.selectedModel,
|
||||
cloudDisabled,
|
||||
]);
|
||||
|
||||
// Add the selected model to the models list if it's not already there
|
||||
|
||||
@@ -9,7 +9,6 @@ interface SettingsState {
|
||||
webSearchEnabled: boolean;
|
||||
selectedModel: string;
|
||||
sidebarOpen: boolean;
|
||||
airplaneMode: boolean;
|
||||
thinkEnabled: boolean;
|
||||
thinkLevel: string;
|
||||
}
|
||||
@@ -51,7 +50,6 @@ export function useSettings() {
|
||||
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
|
||||
selectedModel: settingsData?.settings?.SelectedModel ?? "",
|
||||
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
|
||||
airplaneMode: settingsData?.settings?.AirplaneMode ?? false,
|
||||
}),
|
||||
[settingsData?.settings],
|
||||
);
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { QueryClient } from "@tanstack/react-query";
|
||||
import { createRootRouteWithContext, Outlet } from "@tanstack/react-router";
|
||||
import { getSettings } from "@/api";
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||
|
||||
function RootComponent() {
|
||||
// This hook ensures settings are fetched on app startup
|
||||
@@ -9,6 +10,8 @@ function RootComponent() {
|
||||
queryKey: ["settings"],
|
||||
queryFn: getSettings,
|
||||
});
|
||||
// Fetch cloud status on startup (best-effort)
|
||||
useCloudStatus();
|
||||
|
||||
return (
|
||||
<div>
|
||||
|
||||
@@ -41,14 +41,14 @@ describe("Model merging logic", () => {
|
||||
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
|
||||
});
|
||||
|
||||
it("should hide cloud models in airplane mode", () => {
|
||||
it("should hide cloud models when cloud is disabled", () => {
|
||||
const localModels: Model[] = [
|
||||
new Model({ model: "gpt-oss:120b-cloud" }),
|
||||
new Model({ model: "llama3:latest" }),
|
||||
new Model({ model: "mistral:latest" }),
|
||||
];
|
||||
|
||||
const merged = mergeModels(localModels, true); // airplane mode = true
|
||||
const merged = mergeModels(localModels, true); // cloud disabled = true
|
||||
|
||||
// No cloud models should be present
|
||||
const cloudModels = merged.filter((m) => m.isCloud());
|
||||
|
||||
@@ -32,7 +32,7 @@ function alphabeticalSort(a: Model, b: Model): number {
|
||||
//Merges models, sorting cloud models first, then other models
|
||||
export function mergeModels(
|
||||
localModels: Model[],
|
||||
airplaneMode: boolean = false,
|
||||
hideCloudModels: boolean = false,
|
||||
): Model[] {
|
||||
const allModels = (localModels || []).map((model) => model);
|
||||
|
||||
@@ -95,7 +95,7 @@ export function mergeModels(
|
||||
|
||||
remainingModels.sort(alphabeticalSort);
|
||||
|
||||
return airplaneMode
|
||||
return hideCloudModels
|
||||
? [...featuredModels, ...remainingModels]
|
||||
: [...cloudModels, ...featuredModels, ...remainingModels];
|
||||
}
|
||||
|
||||
37
app/ui/ui.go
37
app/ui/ui.go
@@ -284,12 +284,15 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
||||
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
||||
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
||||
mux.Handle("GET /api/v1/cloud", handle(s.getCloudSetting))
|
||||
mux.Handle("POST /api/v1/cloud", handle(s.cloudSetting))
|
||||
|
||||
// Ollama proxy endpoints
|
||||
ollamaProxy := s.ollamaProxy()
|
||||
mux.Handle("GET /api/tags", ollamaProxy)
|
||||
mux.Handle("POST /api/show", ollamaProxy)
|
||||
mux.Handle("GET /api/version", ollamaProxy)
|
||||
mux.Handle("GET /api/status", ollamaProxy)
|
||||
mux.Handle("HEAD /api/version", ollamaProxy)
|
||||
mux.Handle("POST /api/me", ollamaProxy)
|
||||
mux.Handle("POST /api/signout", ollamaProxy)
|
||||
@@ -1460,6 +1463,40 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) cloudSetting(w http.ResponseWriter, r *http.Request) error {
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return fmt.Errorf("invalid request body: %w", err)
|
||||
}
|
||||
|
||||
if err := s.Store.SetCloudEnabled(req.Enabled); err != nil {
|
||||
return fmt.Errorf("failed to persist cloud setting: %w", err)
|
||||
}
|
||||
|
||||
s.Restart()
|
||||
|
||||
return s.writeCloudStatus(w)
|
||||
}
|
||||
|
||||
func (s *Server) getCloudSetting(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.writeCloudStatus(w)
|
||||
}
|
||||
|
||||
func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
|
||||
disabled, source, err := s.Store.CloudStatus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load cloud status: %w", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(map[string]any{
|
||||
"disabled": disabled,
|
||||
"source": source,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
@@ -115,6 +115,107 @@ func TestHandlePostApiSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePostApiCloudSetting(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
|
||||
testStore := &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
restartCount := 0
|
||||
server := &Server{
|
||||
Store: testStore,
|
||||
Restart: func() {
|
||||
restartCount++
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
body string
|
||||
wantEnabled bool
|
||||
}{
|
||||
{name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
|
||||
{name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/v1/cloud", bytes.NewBufferString(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if err := server.cloudSetting(rr, req); err != nil {
|
||||
t.Fatalf("cloudSetting() error = %v", err)
|
||||
}
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("cloudSetting() invalid response JSON: %v", err)
|
||||
}
|
||||
if got["disabled"] != !tc.wantEnabled {
|
||||
t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
|
||||
}
|
||||
|
||||
disabled, err := testStore.CloudDisabled()
|
||||
if err != nil {
|
||||
t.Fatalf("CloudDisabled() error = %v", err)
|
||||
}
|
||||
if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
|
||||
t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if restartCount != 2 {
|
||||
t.Fatalf("Restart called %d times, want 2", restartCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetApiCloudSetting(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
|
||||
testStore := &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
if err := testStore.SetCloudEnabled(false); err != nil {
|
||||
t.Fatalf("SetCloudEnabled(false) error = %v", err)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
Store: testStore,
|
||||
Restart: func() {},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/cloud", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
if err := server.getCloudSetting(rr, req); err != nil {
|
||||
t.Fatalf("getCloudSetting() error = %v", err)
|
||||
}
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
|
||||
}
|
||||
if got["disabled"] != true {
|
||||
t.Fatalf("response disabled = %v, want true", got["disabled"])
|
||||
}
|
||||
if got["source"] != "config" {
|
||||
t.Fatalf("response source = %v, want config", got["source"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
22
auth/auth.go
22
auth/auth.go
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
// signature is <pubkey>:<signature>
|
||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||
}
|
||||
|
||||
// SignRequest adds a nonce query parameter and an Authorization header with
|
||||
// an Ed25519 signature to req.
|
||||
func SignRequest(ctx context.Context, req *http.Request) error {
|
||||
nonce, err := NewNonce(rand.Reader, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Set("nonce", nonce)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
|
||||
signature, err := Sign(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
88
cmd/cmd.go
88
cmd/cmd.go
@@ -57,9 +57,9 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
useImagegen := false
|
||||
if cmd.Flags().Lookup("imagegen") != nil {
|
||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if useImagegen {
|
||||
opts.Options["use_imagegen_runner"] = true
|
||||
}
|
||||
|
||||
// Fill out the rest of the options based on information about the
|
||||
// model.
|
||||
client, err := api.ClientFromEnvironment()
|
||||
@@ -1885,13 +1900,25 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
|
||||
if version.HasCachedUpdate() {
|
||||
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
|
||||
_ = version.ClearCachedUpdate()
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if available, err := version.CheckForUpdate(ctx); err == nil && available {
|
||||
_ = version.CacheAvailableUpdate()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
@@ -1899,10 +1926,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
@@ -1949,7 +1973,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
launchIntegration := func(name string) bool {
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
@@ -1971,7 +1995,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" {
|
||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
@@ -1999,6 +2023,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
@@ -2008,6 +2035,17 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if len(result.Models) > 0 {
|
||||
// Filter out cloud-disabled models
|
||||
var filtered []string
|
||||
for _, m := range result.Models {
|
||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
result.Models = filtered
|
||||
// Multi-select from modal (Editor integrations)
|
||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
@@ -2017,8 +2055,11 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else if result.Model != "" {
|
||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
||||
continue
|
||||
}
|
||||
// Single-select from modal - save and launch
|
||||
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
|
||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
@@ -2130,6 +2171,9 @@ func NewCLI() *cobra.Command {
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
|
||||
runCmd.Flags().MarkHidden("imagegen")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
Short: "Stop a running model",
|
||||
@@ -2273,6 +2317,7 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_MAX_QUEUE"],
|
||||
envVars["OLLAMA_MODELS"],
|
||||
envVars["OLLAMA_NUM_PARALLEL"],
|
||||
envVars["OLLAMA_NO_CLOUD"],
|
||||
envVars["OLLAMA_NOPRUNE"],
|
||||
envVars["OLLAMA_ORIGINS"],
|
||||
envVars["OLLAMA_SCHED_SPREAD"],
|
||||
@@ -2287,6 +2332,18 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
updateCmd := &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update Ollama to the latest version",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
_ = version.ClearCachedUpdate()
|
||||
return version.DoUpdate(force)
|
||||
},
|
||||
}
|
||||
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
|
||||
|
||||
rootCmd.AddCommand(
|
||||
serveCmd,
|
||||
createCmd,
|
||||
@@ -2304,6 +2361,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
updateCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items)
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"qwen3:8b"})
|
||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
@@ -162,7 +162,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"llama3.2:70b"})
|
||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
@@ -187,7 +187,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"saved-model"})
|
||||
SaveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
|
||||
123
cmd/config/cline.go
Normal file
123
cmd/config/cline.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Cline implements Runner and Editor for the Cline CLI integration
|
||||
type Cline struct{}
|
||||
|
||||
func (c *Cline) String() string { return "Cline" }
|
||||
|
||||
func (c *Cline) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("cline"); err != nil {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (c *Cline) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cline) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ollama as the provider for both act and plan modes
|
||||
baseURL := envconfig.Host().String()
|
||||
config["ollamaBaseUrl"] = baseURL
|
||||
config["actModeApiProvider"] = "ollama"
|
||||
config["actModeOllamaModelId"] = models[0]
|
||||
config["actModeOllamaBaseUrl"] = baseURL
|
||||
config["planModeApiProvider"] = "ollama"
|
||||
config["planModeOllamaModelId"] = models[0]
|
||||
config["planModeOllamaBaseUrl"] = baseURL
|
||||
|
||||
config["welcomeViewCompleted"] = true
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{modelID}
|
||||
}
|
||||
204
cmd/config/cline_test.go
Normal file
204
cmd/config/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClineIntegration(t *testing.T) {
|
||||
c := &Cline{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Cline" {
|
||||
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineEdit(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var config map[string]any
|
||||
json.Unmarshal(data, &config)
|
||||
return config
|
||||
}
|
||||
|
||||
t.Run("creates config from scratch", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeApiProvider"] != "ollama" {
|
||||
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
if config["welcomeViewCompleted"] != true {
|
||||
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing fields", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existing := map[string]any{
|
||||
"remoteRulesToggles": map[string]any{},
|
||||
"remoteWorkflowToggles": map[string]any{},
|
||||
"customSetting": "keep-me",
|
||||
}
|
||||
data, _ := json.Marshal(existing)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["customSetting"] != "keep-me" {
|
||||
t.Errorf("customSetting was not preserved")
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Error("expected no config file to be created for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses first model as primary", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineModels(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
t.Run("returns nil when no config", func(t *testing.T) {
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "anthropic",
|
||||
"actModeOllamaModelId": "some-model",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "ollama",
|
||||
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClinePaths(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -56,8 +56,8 @@ func migrateConfig() (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
// Ignore legacy files with invalid JSON and continue startup.
|
||||
if !json.Valid(oldData) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func save(cfg *config) error {
|
||||
return writeWithBackup(path, data)
|
||||
}
|
||||
|
||||
func saveIntegration(appName string, models []string) error {
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
@@ -604,7 +604,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -619,7 +619,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
@@ -634,7 +634,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -650,7 +650,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
@@ -662,7 +662,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
if err := SaveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
if err := SaveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
@@ -74,14 +74,14 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
@@ -94,7 +94,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
defaultModel := ""
|
||||
@@ -118,7 +118,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
saveIntegration("Claude", []string{"model-x"})
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
@@ -134,8 +134,8 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
saveIntegration("app1", []string{"model-1"})
|
||||
saveIntegration("app2", []string{"model-2"})
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
@@ -172,8 +172,8 @@ func TestListIntegrations(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
saveIntegration("claude", []string{"model-1"})
|
||||
saveIntegration("droid", []string{"model-2"})
|
||||
SaveIntegration("claude", []string{"model-1"})
|
||||
SaveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
@@ -261,7 +261,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveIntegration("test", nil); err != nil {
|
||||
if err := SaveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -281,7 +281,7 @@ func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := saveIntegration("", []string{"model"})
|
||||
err := SaveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
@@ -511,7 +511,7 @@ func TestMigrateConfig(t *testing.T) {
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -51,6 +52,16 @@ func (d *Droid) Run(model string, args []string) error {
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "droid", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -52,6 +53,7 @@ type AliasConfigurer interface {
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
"clawdbot": &Openclaw{},
|
||||
"cline": &Cline{},
|
||||
"codex": &Codex{},
|
||||
"moltbot": &Openclaw{},
|
||||
"droid": &Droid{},
|
||||
@@ -100,16 +102,17 @@ var recommendedVRAM = map[string]string{
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
"moltbot": true,
|
||||
"pi": true,
|
||||
}
|
||||
|
||||
// integrationInstallHints maps integration names to install URLs.
|
||||
var integrationInstallHints = map[string]string{
|
||||
"claude": "https://code.claude.com/docs/en/quickstart",
|
||||
"cline": "https://cline.bot/cli",
|
||||
"openclaw": "https://docs.openclaw.ai",
|
||||
"codex": "https://developers.openai.com/codex/cli/",
|
||||
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
"opencode": "https://opencode.ai",
|
||||
"pi": "https://github.com/badlogic/pi-mono",
|
||||
}
|
||||
|
||||
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
||||
@@ -127,13 +130,21 @@ type IntegrationInfo struct {
|
||||
// integrationDescriptions maps integration names to short descriptions.
|
||||
var integrationDescriptions = map[string]string{
|
||||
"claude": "Anthropic's coding tool with subagents",
|
||||
"cline": "Autonomous coding agent with parallel execution",
|
||||
"codex": "OpenAI's open-source coding agent",
|
||||
"openclaw": "Personal AI with 100+ skills",
|
||||
"droid": "Factory's coding agent across terminal and IDEs",
|
||||
"opencode": "Anomaly's open-source coding agent",
|
||||
"pi": "Minimal AI agent toolkit with plugin support",
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
|
||||
// integrationOrder defines a custom display order for integrations.
|
||||
// Integrations listed here are placed at the end in the given order;
|
||||
// all others appear first, sorted alphabetically.
|
||||
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
|
||||
// with integrationOrder entries placed at the end.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
var result []IntegrationInfo
|
||||
for name, r := range integrations {
|
||||
@@ -146,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
|
||||
Description: integrationDescriptions[name],
|
||||
})
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
|
||||
}
|
||||
|
||||
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
// Both have custom order: sort by their rank
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
// Only one has custom order: it goes last
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
// Neither has custom order: alphabetical
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return result
|
||||
@@ -184,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
|
||||
case "droid":
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
case "cline":
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
case "opencode":
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
case "pi":
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
default:
|
||||
return true // Assume installed for unknown integrations
|
||||
}
|
||||
@@ -212,7 +248,8 @@ type ModelItem struct {
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
@@ -234,6 +271,11 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = filterCloudModels(existing)
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
@@ -242,11 +284,15 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
selected, err := selector("Select model to run:", items)
|
||||
selected, err := selector("Select model to run:", items, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -356,13 +402,11 @@ func selectIntegration() (string, error) {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []ModelItem
|
||||
for _, name := range names {
|
||||
for name, r := range integrations {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
@@ -370,7 +414,25 @@ func selectIntegration() (string, error) {
|
||||
items = append(items, ModelItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items)
|
||||
orderRank := make(map[string]int, len(integrationOrder))
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
slices.SortFunc(items, func(a, b ModelItem) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return DefaultSingleSelector("Select integration:", items, "")
|
||||
}
|
||||
|
||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||
@@ -395,6 +457,11 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = filterCloudModels(existing)
|
||||
}
|
||||
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
@@ -404,6 +471,10 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
|
||||
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
|
||||
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
@@ -419,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := single(prompt, items)
|
||||
model, err := single(prompt, items, current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -510,8 +581,17 @@ func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]b
|
||||
})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = filterCloudModels(existing)
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
@@ -540,6 +620,9 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
@@ -672,25 +755,6 @@ func LaunchIntegrationWithModel(name, modelName string) error {
|
||||
return runIntegration(name, modelName, nil)
|
||||
}
|
||||
|
||||
// SaveIntegrationModel saves the model for an integration.
|
||||
func SaveIntegrationModel(name, modelName string) error {
|
||||
// Load existing models and prepend the new one
|
||||
var models []string
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
models = existing.Models
|
||||
// Remove the model if it already exists
|
||||
for i, m := range models {
|
||||
if m == modelName {
|
||||
models = append(models[:i], models[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Prepend the new model
|
||||
models = append([]string{modelName}, models...)
|
||||
return saveIntegration(name, models)
|
||||
}
|
||||
|
||||
// SaveAndEditIntegration saves the models for an Editor integration and runs its Edit method
|
||||
// to write the integration's config files.
|
||||
func SaveAndEditIntegration(name string, models []string) error {
|
||||
@@ -698,7 +762,7 @@ func SaveAndEditIntegration(name string, models []string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
if err := SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
@@ -709,6 +773,29 @@ func SaveAndEditIntegration(name string, models []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveEditorModels filters out cloud-disabled models before editor launch.
|
||||
// If no models remain, it invokes picker to collect a valid replacement list.
|
||||
func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) {
|
||||
filtered := filterDisabledCloudModels(models)
|
||||
if len(filtered) != len(models) {
|
||||
if err := SaveIntegration(name, filtered); err != nil {
|
||||
return nil, fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
selected, err := picker()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := SaveIntegration(name, selected); err != nil {
|
||||
return nil, fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
|
||||
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
|
||||
r, ok := integrations[name]
|
||||
@@ -743,7 +830,7 @@ func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
if err := SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
@@ -776,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
@@ -837,6 +926,10 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
|
||||
modelFlag = ""
|
||||
}
|
||||
|
||||
// Handle AliasConfigurer integrations (claude, codex)
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
@@ -864,7 +957,7 @@ Examples:
|
||||
model = cfg.Models[0]
|
||||
// AliasConfigurer integrations use single model; sanitize if multiple
|
||||
if len(cfg.Models) > 1 {
|
||||
_ = saveIntegration(name, []string{model})
|
||||
_ = SaveIntegration(name, []string{model})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -876,7 +969,9 @@ Examples:
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
|
||||
model = ""
|
||||
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
||||
model = ""
|
||||
@@ -884,18 +979,16 @@ Examples:
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
// Show picker so user can change model (skip when --model flag provided)
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
@@ -908,7 +1001,7 @@ Examples:
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
if err := saveIntegration(name, []string{model}); err != nil {
|
||||
if err := SaveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
@@ -946,11 +1039,24 @@ Examples:
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
return runIntegration(name, saved.Models[0], passArgs)
|
||||
models = filterDisabledCloudModels(models)
|
||||
if len(models) == 0 {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
current := ""
|
||||
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
|
||||
current = saved.Models[0]
|
||||
}
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
models, err = selectModels(cmd.Context(), name, current)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
@@ -974,7 +1080,7 @@ Examples:
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
if err := SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
@@ -1048,7 +1154,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
if isCloudModelName(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
@@ -1153,7 +1259,55 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
// isCloudModel checks if a model is a cloud model using the Show API.
|
||||
// IsCloudModelDisabled reports whether the given model name looks like a cloud
|
||||
// model and cloud features are currently disabled on the server.
|
||||
func IsCloudModelDisabled(ctx context.Context, name string) bool {
|
||||
if !isCloudModelName(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
disabled, _ := cloudStatusDisabled(ctx, client)
|
||||
return disabled
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||
filtered := existing[:0]
|
||||
for _, m := range existing {
|
||||
if !m.Remote {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// filterDisabledCloudModels removes cloud models from a list when cloud is disabled.
|
||||
func filterDisabledCloudModels(models []string) []string {
|
||||
var filtered []string
|
||||
for _, m := range models {
|
||||
if !IsCloudModelDisabled(context.Background(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||
filtered := items[:0]
|
||||
for _, item := range items {
|
||||
if !isCloudModelName(item.Name) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
@@ -1183,6 +1337,11 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, client)
|
||||
if cloudDisabled {
|
||||
existing = filterCloudModels(existing)
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
@@ -1191,9 +1350,25 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
}
|
||||
|
||||
return items, existingModels
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
@@ -16,6 +16,28 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type stubEditorRunner struct {
|
||||
edited [][]string
|
||||
ranModel string
|
||||
}
|
||||
|
||||
func (s *stubEditorRunner) Run(model string, args []string) error {
|
||||
s.ranModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubEditorRunner) String() string { return "StubEditor" }
|
||||
|
||||
func (s *stubEditorRunner) Paths() []string { return nil }
|
||||
|
||||
func (s *stubEditorRunner) Edit(models []string) error {
|
||||
cloned := append([]string(nil), models...)
|
||||
s.edited = append(s.edited, cloned)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubEditorRunner) Models() []string { return nil }
|
||||
|
||||
func TestIntegrationLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -149,6 +171,10 @@ func TestLaunchCmd_TUICallback(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
@@ -680,7 +706,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save a config for opencode so it looks like a previous launch
|
||||
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
if err := SaveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -697,6 +723,137 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pickerCalled := false
|
||||
models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) {
|
||||
pickerCalled = true
|
||||
return []string{"llama3.2"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
|
||||
}
|
||||
if !pickerCalled {
|
||||
t.Fatal("expected model picker to be called when all models are filtered")
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pickerCalled := false
|
||||
models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) {
|
||||
pickerCalled = true
|
||||
return []string{"qwen3:8b"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
|
||||
}
|
||||
if pickerCalled {
|
||||
t.Fatal("picker should not be called when a local model remains")
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("droid")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &stubEditorRunner{}
|
||||
old, existed := integrations["stubeditor"]
|
||||
integrations["stubeditor"] = stub
|
||||
defer func() {
|
||||
if existed {
|
||||
integrations["stubeditor"] = old
|
||||
} else {
|
||||
delete(integrations, "stubeditor")
|
||||
}
|
||||
}()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
@@ -1091,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sorted by name", func(t *testing.T) {
|
||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
||||
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
||||
// All other entries should be sorted alphabetically before them.
|
||||
orderRank := make(map[string]int)
|
||||
for i, name := range integrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
for i := 1; i < len(infos); i++ {
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
||||
switch {
|
||||
case aRank == 0 && bRank == 0:
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
}
|
||||
case aRank > 0 && bRank == 0:
|
||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
||||
case aRank > 0 && bRank > 0:
|
||||
if aRank >= bRank {
|
||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1234,7 +1407,7 @@ func TestIntegrationModels(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns all saved models", func(t *testing.T) {
|
||||
if err := saveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
|
||||
if err := SaveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := IntegrationModels("droid")
|
||||
|
||||
@@ -2,7 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -32,6 +34,16 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "openclaw", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
@@ -58,7 +70,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
err = cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
@@ -51,6 +52,16 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "opencode", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||
func cursorForCurrent(items []SelectItem, current string) int {
|
||||
if current != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(m.checkOrder) > 0 {
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
m.toggleItem()
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
@@ -576,15 +616,36 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
@@ -596,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
@@ -618,6 +679,11 @@ func (m multiSelectorModel) View() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
@@ -642,7 +708,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -665,7 +731,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,7 +751,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -697,15 +763,18 @@ func (m multiSelectorModel) View() string {
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
@@ -728,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- cursorForCurrent ---
|
||||
|
||||
func TestCursorForCurrent(t *testing.T) {
|
||||
testItems := []SelectItem{
|
||||
{Name: "llama3.2", Recommended: true},
|
||||
{Name: "qwen3:8b", Recommended: true},
|
||||
{Name: "gemma3:latest"},
|
||||
{Name: "deepseek-r1"},
|
||||
{Name: "glm-5:cloud"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current string
|
||||
want int
|
||||
}{
|
||||
{"empty current", "", 0},
|
||||
{"exact match", "qwen3:8b", 1},
|
||||
{"no match returns 0", "nonexistent", 0},
|
||||
{"bare name matches with :latest suffix", "gemma3", 2},
|
||||
{"full tag matches bare item", "llama3.2:latest", 0},
|
||||
{"cloud model exact match", "glm-5:cloud", 4},
|
||||
{"cloud model bare name", "glm-5", 4},
|
||||
{"recommended item exact match", "llama3.2", 0},
|
||||
{"recommended item with tag", "qwen3", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
|
||||
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReorderItems ---
|
||||
|
||||
func TestReorderItems(t *testing.T) {
|
||||
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "[x]") {
|
||||
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiView_DefaultTag(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Error("first checked item should have (default) tag")
|
||||
t.Error("should have (default) tag")
|
||||
}
|
||||
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||
aIdx := strings.Index(content, "a")
|
||||
defaultIdx := strings.Index(content, "(default)")
|
||||
if defaultIdx < aIdx {
|
||||
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,6 +589,200 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-select space toggle (including KeyRunes fallback for Windows PowerShell) ---
|
||||
|
||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeySpace
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if !m.checked[1] {
|
||||
t.Error("space (KeySpace) should toggle the item at cursor")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space should not modify filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if !m.checked[1] {
|
||||
t.Error("space (KeyRunes) should toggle the item at cursor")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune should not be added to filter")
|
||||
}
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should stay at 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single-add mode ---
|
||||
|
||||
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
if m.multi {
|
||||
t.Error("should start in single mode (multi=false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
content := m.View()
|
||||
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||
t.Error("single mode should not show checkboxes")
|
||||
}
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("single mode should show cursor indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.cursor = 1
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if m.singleAdd != "b" {
|
||||
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||
}
|
||||
if !m.confirmed {
|
||||
t.Error("should set confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space in single mode should not toggle items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space rune in single mode should not toggle items")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune in single mode should not add to filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
|
||||
if m.multi {
|
||||
t.Fatal("should start in single mode")
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if !m.multi {
|
||||
t.Error("tab should switch to multi mode")
|
||||
}
|
||||
|
||||
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if m.multi {
|
||||
t.Error("tab should switch back to single mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab add multiple") {
|
||||
t.Error("single mode should show 'tab add multiple' in help")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
|
||||
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||
// preChecked[0] ("a") is the current default and should end up
|
||||
// last in checkOrder so it gets the (default) tag.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||
|
||||
if len(m.checkOrder) != 3 {
|
||||
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||
}
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||
// preChecked[0] ("b") is the default; cursor should start on it
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-mode last-checked is default ---
|
||||
|
||||
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||
m.multi = true
|
||||
|
||||
// Check "alpha" then "gamma"
|
||||
m.cursor = 0
|
||||
m.toggleItem()
|
||||
m.cursor = 2
|
||||
m.toggleItem()
|
||||
|
||||
// Last checked ("gamma") should be at the end of checkOrder
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "gamma" {
|
||||
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
|
||||
// The (default) tag renders based on checkOrder[len-1]
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatal("should show (default) tag")
|
||||
}
|
||||
// "alpha" line should NOT have the default tag
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -131,7 +131,7 @@ type model struct {
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
@@ -209,7 +209,26 @@ func (m *model) openMultiModelModal(integration string) {
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
status, err := client.CloudStatusExperimental(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return status.Cloud.Disabled
|
||||
}
|
||||
|
||||
func cloudModelDisabled(name string) bool {
|
||||
if !isCloudModel(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cloudStatusDisabled(client)
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
@@ -222,6 +241,9 @@ func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if cloudStatusDisabled(client) {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
@@ -272,7 +294,11 @@ func (m *model) loadAvailableModels() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cloudDisabled := cloudStatusDisabled(client)
|
||||
for _, mdl := range models.Models {
|
||||
if cloudDisabled && mdl.RemoteModel != "" {
|
||||
continue
|
||||
}
|
||||
m.availableModels[mdl.Name] = true
|
||||
}
|
||||
}
|
||||
@@ -403,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
for _, idx := range m.multiModalSelector.checkOrder {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
@@ -496,6 +538,15 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
m.openModelModal(configuredModel)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
@@ -226,3 +226,7 @@ curl https://ollama.com/api/chat \
|
||||
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
## Local only
|
||||
|
||||
Ollama can run in local-only mode by [disabling Ollama's cloud](./faq#how-do-i-disable-ollama-cloud) features.
|
||||
@@ -106,20 +106,23 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/index",
|
||||
{
|
||||
"group": "Assistants",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Coding",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Assistants",
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
"/integrations/goose",
|
||||
"/integrations/pi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
20
docs/faq.mdx
20
docs/faq.mdx
@@ -160,6 +160,26 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
|
||||
|
||||
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
|
||||
|
||||
## How do I disable Ollama's cloud features?
|
||||
|
||||
Ollama can run in local only mode by disabling Ollama's cloud features. By turning off Ollama's cloud features, you will lose the ability to use Ollama's cloud models and web search.
|
||||
|
||||
Set `disable_ollama_cloud` in `~/.ollama/server.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"disable_ollama_cloud": true
|
||||
}
|
||||
```
|
||||
|
||||
You can also set the environment variable:
|
||||
|
||||
```shell
|
||||
OLLAMA_NO_CLOUD=1
|
||||
```
|
||||
|
||||
Restart Ollama after changing configuration. Once disabled, Ollama's logs will show `Ollama cloud disabled: true`.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
|
||||
|
||||
@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
- [Pi](/integrations/pi)
|
||||
|
||||
## Assistants
|
||||
|
||||
|
||||
57
docs/integrations/pi.mdx
Normal file
57
docs/integrations/pi.mdx
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Pi
|
||||
---
|
||||
|
||||
Pi is a minimal AI agent toolkit with plugin support.
|
||||
|
||||
## Install
|
||||
|
||||
Install [Pi](https://github.com/badlogic/pi-mono):
|
||||
|
||||
```bash
|
||||
npm install -g @mariozechner/pi-coding-agent
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch pi
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch pi --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.pi/agent/models.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{
|
||||
"id": "qwen3-coder"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Update `~/.pi/agent/settings.json` to set the default provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"defaultProvider": "ollama",
|
||||
"defaultModel": "qwen3-coder"
|
||||
}
|
||||
```
|
||||
@@ -27,9 +27,17 @@ The menu provides quick access to:
|
||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||
- **Additional integrations** - Available under "More..."
|
||||
|
||||
## Assistants
|
||||
|
||||
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
|
||||
|
||||
```sh
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
## Coding
|
||||
|
||||
Launch coding tools with Ollama models:
|
||||
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch claude
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -206,6 +209,8 @@ var (
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
EnableVulkan = Bool("OLLAMA_VULKAN")
|
||||
// NoCloudEnv checks the OLLAMA_NO_CLOUD environment variable.
|
||||
NoCloudEnv = Bool("OLLAMA_NO_CLOUD")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -285,6 +290,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
@@ -334,3 +340,91 @@ func Values() map[string]string {
|
||||
func Var(key string) string {
|
||||
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
|
||||
}
|
||||
|
||||
// serverConfigData holds the parsed fields from ~/.ollama/server.json.
|
||||
type serverConfigData struct {
|
||||
DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
serverCfgMu sync.RWMutex
|
||||
serverCfgLoaded bool
|
||||
serverCfg serverConfigData
|
||||
)
|
||||
|
||||
func loadServerConfig() {
|
||||
serverCfgMu.RLock()
|
||||
if serverCfgLoaded {
|
||||
serverCfgMu.RUnlock()
|
||||
return
|
||||
}
|
||||
serverCfgMu.RUnlock()
|
||||
|
||||
cfg := serverConfigData{}
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
path := filepath.Join(home, ".ollama", "server.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("envconfig: could not read server config", "error", err)
|
||||
}
|
||||
} else if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
slog.Debug("envconfig: could not parse server config", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
serverCfgMu.Lock()
|
||||
defer serverCfgMu.Unlock()
|
||||
if serverCfgLoaded {
|
||||
return
|
||||
}
|
||||
serverCfg = cfg
|
||||
serverCfgLoaded = true
|
||||
}
|
||||
|
||||
func cachedServerConfig() serverConfigData {
|
||||
serverCfgMu.RLock()
|
||||
defer serverCfgMu.RUnlock()
|
||||
return serverCfg
|
||||
}
|
||||
|
||||
// ReloadServerConfig refreshes the cached ~/.ollama/server.json settings.
|
||||
func ReloadServerConfig() {
|
||||
serverCfgMu.Lock()
|
||||
serverCfgLoaded = false
|
||||
serverCfg = serverConfigData{}
|
||||
serverCfgMu.Unlock()
|
||||
|
||||
loadServerConfig()
|
||||
}
|
||||
|
||||
// NoCloud returns true if Ollama cloud features are disabled,
|
||||
// checking both the OLLAMA_NO_CLOUD environment variable and
|
||||
// the disable_ollama_cloud field in ~/.ollama/server.json.
|
||||
func NoCloud() bool {
|
||||
if NoCloudEnv() {
|
||||
return true
|
||||
}
|
||||
loadServerConfig()
|
||||
return cachedServerConfig().DisableOllamaCloud
|
||||
}
|
||||
|
||||
// NoCloudSource returns the source of the cloud-disabled decision.
|
||||
// Returns "none", "env", "config", or "both".
|
||||
func NoCloudSource() string {
|
||||
envDisabled := NoCloudEnv()
|
||||
loadServerConfig()
|
||||
configDisabled := cachedServerConfig().DisableOllamaCloud
|
||||
|
||||
switch {
|
||||
case envDisabled && configDisabled:
|
||||
return "both"
|
||||
case envDisabled:
|
||||
return "env"
|
||||
case configDisabled:
|
||||
return "config"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package envconfig
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -326,3 +328,81 @@ func TestLogLevel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoCloud(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
configContent string
|
||||
wantDisabled bool
|
||||
wantSource string
|
||||
}{
|
||||
{
|
||||
name: "neither env nor config",
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
{
|
||||
name: "env only",
|
||||
envValue: "1",
|
||||
wantDisabled: true,
|
||||
wantSource: "env",
|
||||
},
|
||||
{
|
||||
name: "config only",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
wantDisabled: true,
|
||||
wantSource: "config",
|
||||
},
|
||||
{
|
||||
name: "both env and config",
|
||||
envValue: "1",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
wantDisabled: true,
|
||||
wantSource: "both",
|
||||
},
|
||||
{
|
||||
name: "config false",
|
||||
configContent: `{"disable_ollama_cloud": false}`,
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
{
|
||||
name: "invalid config ignored",
|
||||
configContent: `{invalid json`,
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
{
|
||||
name: "no config file",
|
||||
wantDisabled: false,
|
||||
wantSource: "none",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
if tt.configContent != "" {
|
||||
configDir := filepath.Join(home, ".ollama")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "server.json"), []byte(tt.configContent), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
setTestHome(t, home)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||
|
||||
if got := NoCloud(); got != tt.wantDisabled {
|
||||
t.Errorf("NoCloud() = %v, want %v", got, tt.wantDisabled)
|
||||
}
|
||||
|
||||
if got := NoCloudSource(); got != tt.wantSource {
|
||||
t.Errorf("NoCloudSource() = %q, want %q", got, tt.wantSource)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
10
envconfig/test_home_test.go
Normal file
10
envconfig/test_home_test.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package envconfig
|
||||
|
||||
import "testing"
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
ReloadServerConfig()
|
||||
}
|
||||
25
internal/cloud/policy.go
Normal file
25
internal/cloud/policy.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const DisabledMessagePrefix = "ollama cloud is disabled"
|
||||
|
||||
// Status returns whether cloud is disabled and the source of the decision.
|
||||
// Source is one of: "none", "env", "config", "both".
|
||||
func Status() (disabled bool, source string) {
|
||||
return envconfig.NoCloud(), envconfig.NoCloudSource()
|
||||
}
|
||||
|
||||
func Disabled() bool {
|
||||
return envconfig.NoCloud()
|
||||
}
|
||||
|
||||
func DisabledError(operation string) string {
|
||||
if operation == "" {
|
||||
return DisabledMessagePrefix
|
||||
}
|
||||
|
||||
return DisabledMessagePrefix + ": " + operation
|
||||
}
|
||||
85
internal/cloud/policy_test.go
Normal file
85
internal/cloud/policy_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
configContent string
|
||||
disabled bool
|
||||
source string
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
disabled: false,
|
||||
source: "none",
|
||||
},
|
||||
{
|
||||
name: "env only",
|
||||
envValue: "1",
|
||||
disabled: true,
|
||||
source: "env",
|
||||
},
|
||||
{
|
||||
name: "config only",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
disabled: true,
|
||||
source: "config",
|
||||
},
|
||||
{
|
||||
name: "both",
|
||||
envValue: "1",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
disabled: true,
|
||||
source: "both",
|
||||
},
|
||||
{
|
||||
name: "invalid config ignored",
|
||||
configContent: `{invalid json`,
|
||||
disabled: false,
|
||||
source: "none",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
if tt.configContent != "" {
|
||||
configPath := filepath.Join(home, ".ollama", "server.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
setTestHome(t, home)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||
|
||||
disabled, source := Status()
|
||||
if disabled != tt.disabled {
|
||||
t.Fatalf("disabled: expected %v, got %v", tt.disabled, disabled)
|
||||
}
|
||||
if source != tt.source {
|
||||
t.Fatalf("source: expected %q, got %q", tt.source, source)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledError(t *testing.T) {
|
||||
if got := DisabledError(""); got != DisabledMessagePrefix {
|
||||
t.Fatalf("expected %q, got %q", DisabledMessagePrefix, got)
|
||||
}
|
||||
|
||||
want := DisabledMessagePrefix + ": remote inference is unavailable"
|
||||
if got := DisabledError("remote inference is unavailable"); got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
14
internal/cloud/test_home_test.go
Normal file
14
internal/cloud/test_home_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
@@ -2,15 +2,22 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
@@ -18,7 +25,6 @@ type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
@@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
return writeSSE(w.ResponseWriter, eventType, data)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
@@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
@@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
|
||||
// executes the search, re-invokes the model with results, and assembles the
|
||||
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
|
||||
type WebSearchAnthropicWriter struct {
|
||||
BaseWriter
|
||||
newLoopContext func() (context.Context, context.CancelFunc)
|
||||
inner *AnthropicWriter
|
||||
req anthropic.MessagesRequest // original Anthropic request
|
||||
chatReq *api.ChatRequest // converted Ollama request (for followup calls)
|
||||
stream bool
|
||||
|
||||
estimatedInputTokens int
|
||||
|
||||
terminalSent bool
|
||||
|
||||
observedPromptEvalCount int
|
||||
observedEvalCount int
|
||||
|
||||
loopInFlight bool
|
||||
loopBaseInputTok int
|
||||
loopBaseOutputTok int
|
||||
loopResultCh chan webSearchLoopResult
|
||||
|
||||
streamMessageStarted bool
|
||||
streamHasOpenBlock bool
|
||||
streamOpenBlockIndex int
|
||||
streamNextIndex int
|
||||
}
|
||||
|
||||
const maxWebSearchLoops = 3
|
||||
|
||||
type webSearchLoopResult struct {
|
||||
response anthropic.MessagesResponse
|
||||
loopErr *webSearchLoopError
|
||||
}
|
||||
|
||||
type webSearchLoopError struct {
|
||||
code string
|
||||
query string
|
||||
usage anthropic.Usage
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *webSearchLoopError) Error() string {
|
||||
if e.err == nil {
|
||||
return e.code
|
||||
}
|
||||
return fmt.Sprintf("%s: %v", e.code, e.err)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
|
||||
if w.terminalSent {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.inner.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.recordObservedUsage(chatResponse.Metrics)
|
||||
|
||||
if w.stream && w.loopInFlight {
|
||||
if !chatResponse.Done {
|
||||
return len(data), nil
|
||||
}
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
|
||||
logutil.Trace("anthropic middleware: upstream chunk",
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"web_search", hasWebSearch,
|
||||
"other_tools", hasOtherTools,
|
||||
)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
if w.stream {
|
||||
if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
return w.inner.writeResponse(data)
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
// Let the original generation continue to completion while web search runs in parallel.
|
||||
logutil.Trace("anthropic middleware: starting async web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
)
|
||||
w.startLoopWorker(chatResponse, webSearchCall)
|
||||
if chatResponse.Done {
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
loopCtx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
|
||||
}
|
||||
logutil.Trace("anthropic middleware: starting sync web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"usage", initialUsage,
|
||||
)
|
||||
response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
|
||||
if loopErr != nil {
|
||||
return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
|
||||
}
|
||||
|
||||
if err := w.writeTerminalResponse(response); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
|
||||
followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
|
||||
followUpMessages = append(followUpMessages, w.chatReq.Messages...)
|
||||
|
||||
followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
|
||||
usage := initialUsage
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
|
||||
"model", w.req.Model,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
"messages", len(followUpMessages),
|
||||
"tools", len(followUpTools),
|
||||
"max_loops", maxWebSearchLoops,
|
||||
)
|
||||
|
||||
currentResponse := initialResponse
|
||||
currentToolCall := initialToolCall
|
||||
|
||||
var serverContent []anthropic.ContentBlock
|
||||
|
||||
if !isCloudModelName(w.req.Model) {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model")
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "web_search_not_supported_for_local_models",
|
||||
query: extractQueryFromToolCall(&initialToolCall),
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
for loop := 1; loop <= maxWebSearchLoops; loop++ {
|
||||
query := extractQueryFromToolCall(¤tToolCall)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
|
||||
"loop", loop,
|
||||
"query", anthropic.TraceTruncateString(query),
|
||||
"messages", len(followUpMessages),
|
||||
)
|
||||
if query == "" {
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "invalid_request",
|
||||
query: "",
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
const defaultMaxResults = 5
|
||||
searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "unavailable",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search results",
|
||||
"loop", loop,
|
||||
"results", len(searchResp.Results),
|
||||
)
|
||||
|
||||
toolUseID := loopServerToolUseID(w.inner.id, loop)
|
||||
searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: searchResults,
|
||||
},
|
||||
)
|
||||
|
||||
assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
|
||||
toolResultMsg := api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchResultsForToolMessage(searchResp.Results),
|
||||
ToolCallID: currentToolCall.ID,
|
||||
}
|
||||
followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
|
||||
|
||||
followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "api_error",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup response",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceChatResponse(followUpResponse),
|
||||
)
|
||||
|
||||
usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
|
||||
usage.OutputTokens += followUpResponse.Metrics.EvalCount
|
||||
|
||||
nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceMessagesResponse(finalResponse),
|
||||
)
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
currentResponse = followUpResponse
|
||||
currentToolCall = nextToolCall
|
||||
}
|
||||
|
||||
maxLoopQuery := extractQueryFromToolCall(¤tToolCall)
|
||||
maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: maxLoopToolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": maxLoopQuery},
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: maxLoopToolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
maxResponse := anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: serverContent,
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
|
||||
"resp", anthropic.TraceMessagesResponse(maxResponse),
|
||||
)
|
||||
return maxResponse, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
|
||||
if w.loopInFlight {
|
||||
return
|
||||
}
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
|
||||
}
|
||||
w.loopBaseInputTok = initialUsage.InputTokens
|
||||
w.loopBaseOutputTok = initialUsage.OutputTokens
|
||||
w.loopResultCh = make(chan webSearchLoopResult, 1)
|
||||
w.loopInFlight = true
|
||||
logutil.Trace("anthropic middleware: loop worker started",
|
||||
"usage", initialUsage,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
)
|
||||
|
||||
go func() {
|
||||
ctx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
|
||||
w.loopResultCh <- webSearchLoopResult{
|
||||
response: response,
|
||||
loopErr: loopErr,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeLoopResult() error {
|
||||
if w.loopResultCh == nil {
|
||||
return w.sendError("api_error", "", w.currentObservedUsage())
|
||||
}
|
||||
|
||||
result := <-w.loopResultCh
|
||||
w.loopResultCh = nil
|
||||
w.loopInFlight = false
|
||||
if result.loopErr != nil {
|
||||
logutil.Trace("anthropic middleware: loop worker returned error",
|
||||
"code", result.loopErr.code,
|
||||
"query", result.loopErr.query,
|
||||
"usage", result.loopErr.usage,
|
||||
"error", result.loopErr.err,
|
||||
)
|
||||
usage := result.loopErr.usage
|
||||
w.applyObservedUsageDeltaToUsage(&usage)
|
||||
return w.sendError(result.loopErr.code, result.loopErr.query, usage)
|
||||
}
|
||||
logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
|
||||
|
||||
w.applyObservedUsageDelta(&result.response)
|
||||
return w.writeTerminalResponse(result.response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
|
||||
w.applyObservedUsageDeltaToUsage(&response.Usage)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
|
||||
if metrics.PromptEvalCount > w.observedPromptEvalCount {
|
||||
w.observedPromptEvalCount = metrics.PromptEvalCount
|
||||
}
|
||||
if metrics.EvalCount > w.observedEvalCount {
|
||||
w.observedEvalCount = metrics.EvalCount
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
|
||||
if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
|
||||
usage.InputTokens += deltaIn
|
||||
}
|
||||
if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
|
||||
usage.OutputTokens += deltaOut
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
|
||||
return anthropic.Usage{
|
||||
InputTokens: w.observedPromptEvalCount,
|
||||
OutputTokens: w.observedEvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
|
||||
if w.newLoopContext != nil {
|
||||
return w.newLoopContext()
|
||||
}
|
||||
return context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
|
||||
|
||||
content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
|
||||
content = append(content, serverContent...)
|
||||
content = append(content, converted.Content...)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: content,
|
||||
StopReason: converted.StopReason,
|
||||
StopSequence: converted.StopSequence,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{webSearchCall},
|
||||
}
|
||||
if response.Message.Content != "" {
|
||||
assistantMsg.Content = response.Message.Content
|
||||
}
|
||||
if response.Message.Thinking != "" {
|
||||
assistantMsg.Thinking = response.Message.Thinking
|
||||
}
|
||||
return assistantMsg
|
||||
}
|
||||
|
||||
func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
|
||||
var resultText strings.Builder
|
||||
for _, r := range results {
|
||||
fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
|
||||
if r.Content != "" {
|
||||
fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
return resultText.String()
|
||||
}
|
||||
|
||||
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
|
||||
var webSearchCall api.ToolCall
|
||||
hasWebSearch := false
|
||||
hasOtherTools := false
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Function.Name == "web_search" {
|
||||
if !hasWebSearch {
|
||||
webSearchCall = toolCall
|
||||
hasWebSearch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
hasOtherTools = true
|
||||
}
|
||||
|
||||
return webSearchCall, hasWebSearch, hasOtherTools
|
||||
}
|
||||
|
||||
func loopServerToolUseID(messageID string, loop int) string {
|
||||
base := serverToolUseID(messageID)
|
||||
if loop <= 1 {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s_%d", base, loop)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
|
||||
streaming := false
|
||||
followUp := api.ChatRequest{
|
||||
Model: w.chatReq.Model,
|
||||
Messages: messages,
|
||||
Stream: &streaming,
|
||||
Tools: tools,
|
||||
Options: w.chatReq.Options,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(followUp)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
|
||||
chatURL := envconfig.Host().String() + "/api/chat"
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup request",
|
||||
"url", chatURL,
|
||||
"req", anthropic.TraceChatRequest(&followUp),
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
|
||||
"status", resp.StatusCode,
|
||||
"response", strings.TrimSpace(string(respBody)),
|
||||
)
|
||||
return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
var chatResp api.ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
|
||||
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
|
||||
events := w.inner.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
switch e := event.Data.(type) {
|
||||
case anthropic.MessageStartEvent:
|
||||
w.streamMessageStarted = true
|
||||
case anthropic.ContentBlockStartEvent:
|
||||
w.streamHasOpenBlock = true
|
||||
w.streamOpenBlockIndex = e.Index
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.ContentBlockStopEvent:
|
||||
if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
|
||||
w.streamHasOpenBlock = false
|
||||
}
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.MessageStopEvent:
|
||||
w.terminalSent = true
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
|
||||
if w.streamMessageStarted {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens
|
||||
if inputTokens == 0 {
|
||||
inputTokens = w.estimatedInputTokens
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: inputTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamMessageStarted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
|
||||
if !w.streamHasOpenBlock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: w.streamOpenBlockIndex,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.streamOpenBlockIndex+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = w.streamOpenBlockIndex + 1
|
||||
}
|
||||
w.streamHasOpenBlock = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
|
||||
for _, block := range content {
|
||||
index := w.streamNextIndex
|
||||
if block.Type == "text" {
|
||||
emptyText := ""
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: anthropic.ContentBlock{
|
||||
Type: "text",
|
||||
Text: &emptyText,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
text := ""
|
||||
if block.Text != nil {
|
||||
text = *block.Text
|
||||
}
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: anthropic.Delta{
|
||||
Type: "text_delta",
|
||||
Text: text,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: block,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: index,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamNextIndex++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
|
||||
if w.terminalSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.ensureStreamMessageStart(response.Usage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.closeOpenStreamBlock(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.writeStreamContentBlocks(response.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: anthropic.MessageDelta{
|
||||
StopReason: response.StopReason,
|
||||
},
|
||||
Usage: anthropic.DeltaUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamResponse emits a complete MessagesResponse as SSE events.
|
||||
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
toolUseID := serverToolUseID(w.inner.id)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: errorCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// sendError sends a web search error response.
|
||||
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
|
||||
response := w.webSearchErrorResponse(errorCode, query, usage)
|
||||
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestCtx := c.Request.Context()
|
||||
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
@@ -134,11 +865,10 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
innerWriter := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
@@ -148,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
if hasWebSearchTool(req.Tools) {
|
||||
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
|
||||
// for cloud models. Local models may still receive web_search tool definitions;
|
||||
// execution is validated when the model actually emits a web_search tool call.
|
||||
if isCloudModelName(req.Model) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer = &WebSearchAnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
newLoopContext: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(requestCtx, 5*time.Minute)
|
||||
},
|
||||
inner: innerWriter,
|
||||
req: req,
|
||||
chatReq: chatReq,
|
||||
stream: req.Stream,
|
||||
estimatedInputTokens: estimatedTokens,
|
||||
}
|
||||
} else {
|
||||
c.Writer = innerWriter
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// hasWebSearchTool checks if the request tools include a web_search tool
|
||||
func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Type, "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
func extractQueryFromToolCall(tc *api.ToolCall) string {
|
||||
q, ok := tc.Function.Arguments.Get("query")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := q.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeSSE writes a Server-Sent Event
|
||||
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverToolUseID derives a server tool use ID from a message ID
|
||||
func serverToolUseID(messageID string) string {
|
||||
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
22
middleware/test_home_test.go
Normal file
22
middleware/test_home_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
|
||||
// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
|
||||
// so that cloud features are enabled for the duration of the test.
|
||||
func enableCloudForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
setTestHome(t, t.TempDir())
|
||||
}
|
||||
@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
|
||||
var p Parser
|
||||
|
||||
switch name {
|
||||
case "qwen3":
|
||||
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
||||
name string
|
||||
}{
|
||||
{"passthrough"},
|
||||
{"qwen3"},
|
||||
{"qwen3-thinking"},
|
||||
{"qwen3-coder"},
|
||||
{"harmony"},
|
||||
}
|
||||
|
||||
335
model/parsers/qwen3.go
Normal file
335
model/parsers/qwen3.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type qwen3ParserState int
|
||||
|
||||
const (
|
||||
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
|
||||
qwen3ParserStateThinkingStartedEatingWhitespace
|
||||
qwen3ParserStateCollectingThinking
|
||||
qwen3ParserStateThinkingDoneEatingWhitespace
|
||||
qwen3ParserStateCollectingContent
|
||||
qwen3ParserStateToolStartedEatingWhitespace
|
||||
qwen3ParserStateCollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
qwen3ThinkingOpenTag = "<think>"
|
||||
qwen3ThinkingCloseTag = "</think>"
|
||||
qwen3ToolOpenTag = "<tool_call>"
|
||||
qwen3ToolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
|
||||
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
|
||||
// with thinking content directly (without an opening tag).
|
||||
type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
thinkingEnabled = p.defaultThinking
|
||||
}
|
||||
|
||||
if p.hasThinkingSupport && thinkingEnabled {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
p.maybeThinkingOpenAtBOL = true
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
type qwen3Event interface {
|
||||
isQwen3Event()
|
||||
}
|
||||
|
||||
type qwen3EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen3EventContent) isQwen3Event() {}
|
||||
|
||||
type qwen3EventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
func (qwen3EventRawToolCall) isQwen3Event() {}
|
||||
|
||||
type qwen3EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen3EventThinkingContent) isQwen3Event() {}
|
||||
|
||||
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwen3EventRawToolCall:
|
||||
toolCall, err := parseQwen3ToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case qwen3EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) parseEvents() []qwen3Event {
|
||||
var all []qwen3Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwen3Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||
}
|
||||
|
||||
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
var events []qwen3Event
|
||||
|
||||
switch p.state {
|
||||
case qwen3ParserStateLookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
return events, false
|
||||
}
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
|
||||
case qwen3ParserStateThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
|
||||
|
||||
case qwen3ParserStateCollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
|
||||
// Some qwen3 checkpoints emit an explicit opening <think> tag even
|
||||
// though the prompt already ended with <think>. Strip exactly one
|
||||
// leading opening tag if present.
|
||||
if p.maybeThinkingOpenAtBOL {
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
return events, true
|
||||
}
|
||||
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||
return events, false
|
||||
}
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
|
||||
|
||||
case qwen3ParserStateCollectingContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolOpenTag) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen3EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen3ParserStateToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
|
||||
|
||||
case qwen3ParserStateCollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, qwen3ToolCloseTag) {
|
||||
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("qwen3 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, qwen3EventRawToolCall{raw: toolContent})
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
return events, true
|
||||
}
|
||||
return events, false
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
if parsed.Name == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
|
||||
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for key, value := range parsed.Arguments {
|
||||
toolCall.Function.Arguments.Set(key, value)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
147
model/parsers/qwen3_test.go
Normal file
147
model/parsers/qwen3_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen3ParserThinkingEnabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("<thi", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "" || len(calls) != 0 {
|
||||
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingDisabled(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Direct answer" {
|
||||
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCall(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
location, ok := calls[0].Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||
}
|
||||
unit, ok := calls[0].Function.Arguments.Get("unit")
|
||||
if !ok || unit != "celsius" {
|
||||
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||
}
|
||||
}
|
||||
@@ -115,6 +115,15 @@ func (s *store) saveLocked() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read existing file into a generic map to preserve unknown fields
|
||||
// (e.g. disable_ollama_cloud) that aliasStore doesn't own.
|
||||
existing := make(map[string]json.RawMessage)
|
||||
if data, err := os.ReadFile(s.path); err == nil {
|
||||
if err := json.Unmarshal(data, &existing); err != nil {
|
||||
slog.Debug("failed to parse existing server config; preserving unknown fields skipped", "path", s.path, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
@@ -126,10 +135,17 @@ func (s *store) saveLocked() error {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := serverConfig{
|
||||
Version: serverConfigVersion,
|
||||
Aliases: entries,
|
||||
// Overwrite only the keys we own
|
||||
versionJSON, err := json.Marshal(serverConfigVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
aliasesJSON, err := json.Marshal(entries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
existing["version"] = versionJSON
|
||||
existing["aliases"] = aliasesJSON
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
@@ -138,7 +154,7 @@ func (s *store) saveLocked() error {
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
if err := enc.Encode(existing); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
@@ -58,6 +59,11 @@ import (
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
|
||||
const (
|
||||
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -144,12 +150,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
|
||||
delete(requestOpts, "use_imagegen_runner")
|
||||
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
@@ -229,6 +238,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
return
|
||||
}
|
||||
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
@@ -1066,9 +1080,12 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case errors.As(err, &statusErr):
|
||||
c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage})
|
||||
case err.Error() == errtypes.InvalidModelNameErrMsg:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
@@ -1095,6 +1112,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
return nil, api.StatusError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
ParentModel: m.ParentModel,
|
||||
Format: m.Config.ModelFormat,
|
||||
@@ -1571,6 +1597,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
|
||||
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
r.GET("/api/status", s.StatusHandler)
|
||||
|
||||
// Local model cache management (new implementation is at end of function)
|
||||
r.POST("/api/pull", s.PullHandler)
|
||||
@@ -1634,6 +1661,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
func Serve(ln net.Listener) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("server config", "env", envconfig.Values())
|
||||
cloudDisabled, _ := internalcloud.Status()
|
||||
slog.Info(fmt.Sprintf("Ollama cloud disabled: %t", cloudDisabled))
|
||||
|
||||
blobsDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
@@ -1824,6 +1853,16 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) StatusHandler(c *gin.Context) {
|
||||
disabled, source := internalcloud.Status()
|
||||
c.JSON(http.StatusOK, api.StatusResponse{
|
||||
Cloud: api.CloudStatus{
|
||||
Disabled: disabled,
|
||||
Source: source,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
@@ -2010,6 +2049,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
return
|
||||
}
|
||||
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
|
||||
func TestAliasShadowingRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
@@ -40,7 +41,7 @@ func TestAliasShadowingRejected(t *testing.T) {
|
||||
|
||||
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
var remoteModel string
|
||||
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -256,7 +257,7 @@ func TestPrefixAliasChain(t *testing.T) {
|
||||
|
||||
func TestPrefixAliasCRUD(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
@@ -364,7 +365,7 @@ func TestPrefixAliasCaseInsensitive(t *testing.T) {
|
||||
|
||||
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
@@ -424,3 +425,51 @@ func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasSavePreservesCloudDisable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "server.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
initial := map[string]any{
|
||||
"version": serverConfigVersion,
|
||||
"disable_ollama_cloud": true,
|
||||
"aliases": []aliasEntry{},
|
||||
}
|
||||
data, err := json.Marshal(initial)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
updated, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var updatedCfg map[string]json.RawMessage
|
||||
if err := json.Unmarshal(updated, &updatedCfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, ok := updatedCfg["disable_ollama_cloud"]
|
||||
if !ok {
|
||||
t.Fatal("expected disable_ollama_cloud key to be preserved")
|
||||
}
|
||||
if string(raw) != "true" {
|
||||
t.Fatalf("expected disable_ollama_cloud to remain true, got %s", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
94
server/routes_cloud_test.go
Normal file
94
server/routes_cloud_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
func TestStatusHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.StatusHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.StatusResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !resp.Cloud.Disabled {
|
||||
t.Fatalf("expected cloud.disabled true, got false")
|
||||
}
|
||||
if resp.Cloud.Source != "env" {
|
||||
t.Fatalf("expected cloud.source env, got %q", resp.Cloud.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
|
||||
s := Server{}
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("chat remote blocked", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-cloud",
|
||||
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||
})
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d", w.Code)
|
||||
}
|
||||
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
|
||||
t.Fatalf("unexpected response: %s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("generate remote blocked", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-cloud",
|
||||
Prompt: "hi",
|
||||
})
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d", w.Code)
|
||||
}
|
||||
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
|
||||
t.Fatalf("unexpected response: %s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("show remote blocked", func(t *testing.T) {
|
||||
w := createRequest(t, s.ShowHandler, api.ShowRequest{
|
||||
Model: "test-cloud",
|
||||
})
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d", w.Code)
|
||||
}
|
||||
if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable)+`"}` {
|
||||
t.Fatalf("unexpected response: %s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2371,29 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
"": {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model manifest with image capability
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
@@ -2409,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loadedModel, err := GetModel("test-image")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
schedulerModelKey(loadedModel): {
|
||||
llama: &mock,
|
||||
Options: &opts,
|
||||
model: loadedModel,
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
|
||||
107
server/sched.go
107
server/sched.go
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -32,6 +33,7 @@ type LlmRequest struct {
|
||||
successCh chan *runnerRef
|
||||
errCh chan error
|
||||
schedAttempts uint
|
||||
useImagegen bool
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
@@ -81,8 +83,30 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
return sched
|
||||
}
|
||||
|
||||
// schedulerModelKey returns the scheduler map key for a model.
|
||||
// GGUF-backed models use ModelPath; safetensors/image models without a
|
||||
// ModelPath use manifest digest so distinct models don't collide.
|
||||
func schedulerModelKey(m *Model) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if m.ModelPath != "" {
|
||||
return m.ModelPath
|
||||
}
|
||||
if m.Digest != "" {
|
||||
return "digest:" + m.Digest
|
||||
}
|
||||
if m.Name != "" {
|
||||
return "name:" + m.Name
|
||||
}
|
||||
if m.ShortName != "" {
|
||||
return "short:" + m.ShortName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
@@ -99,10 +123,12 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
||||
sessionDuration: sessionDuration,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
useImagegen: useImagegen,
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[req.model.ModelPath]
|
||||
runner := s.loaded[key]
|
||||
s.loadedMu.Unlock()
|
||||
if runner != nil && !runner.needsReload(c, req) {
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
@@ -148,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
pendingKey := schedulerModelKey(pending.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pending.model.ModelPath]
|
||||
runner := s.loaded[pendingKey]
|
||||
loadedCount := len(s.loaded)
|
||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
@@ -163,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
runnerToExpire = runner
|
||||
} else {
|
||||
// Runner is usable, return it
|
||||
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
|
||||
logutil.Trace("using existing loaded runner", "model", pendingKey)
|
||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||
break
|
||||
}
|
||||
@@ -289,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
slog.Debug("shutting down scheduler completed loop")
|
||||
return
|
||||
case finished := <-s.finishedReqCh:
|
||||
finishedKey := schedulerModelKey(finished.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[finished.model.ModelPath]
|
||||
runner := s.loaded[finishedKey]
|
||||
s.loadedMu.Unlock()
|
||||
if runner == nil {
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
|
||||
continue
|
||||
}
|
||||
runner.refMu.Lock()
|
||||
@@ -344,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
|
||||
s.loadedMu.Lock()
|
||||
slog.Debug("got lock to unload expired event", "runner", runner)
|
||||
runnerToUnload := s.loaded[runner.modelPath]
|
||||
runnerToUnload := s.loaded[runner.modelKey]
|
||||
if runnerToUnload == nil {
|
||||
// If runnerToUnload is nil, we already processed an event and
|
||||
// unloaded it. This double unload can happen if the initial
|
||||
@@ -373,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
}
|
||||
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
||||
runner.unload()
|
||||
delete(s.loaded, runner.modelPath)
|
||||
delete(s.loaded, runner.modelKey)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
||||
<-finished
|
||||
@@ -511,6 +539,7 @@ iGPUScan:
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: llama,
|
||||
Options: &req.opts,
|
||||
sessionDuration: sessionDuration,
|
||||
@@ -525,7 +554,7 @@ iGPUScan:
|
||||
runner.refMu.Lock() // hold lock until running or aborted
|
||||
|
||||
s.loadedMu.Lock()
|
||||
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
|
||||
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
|
||||
// Shouldn't happen, but safeguard against leaking a runner
|
||||
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
||||
oldRunner.refMu.Lock()
|
||||
@@ -533,7 +562,7 @@ iGPUScan:
|
||||
oldRunner.refMu.Unlock()
|
||||
}
|
||||
s.activeLoading = nil
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
slog.Info("loaded runners", "count", len(s.loaded))
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
@@ -566,17 +595,20 @@ iGPUScan:
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode imagegen.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = imagegen.ModeImageGen
|
||||
} else {
|
||||
mode = imagegen.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := imagegen.NewServer(modelName, mode)
|
||||
var server llm.LlamaServer
|
||||
var err error
|
||||
|
||||
isImagegen := false
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
|
||||
isImagegen = true
|
||||
} else if req.useImagegen {
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
|
||||
isImagegen = true
|
||||
} else {
|
||||
server, err = mlxrunner.NewClient(modelName)
|
||||
}
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
@@ -590,16 +622,18 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
isImagegen: isImagegen,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loaded[runner.modelKey] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Set up expiration timer
|
||||
@@ -667,6 +701,7 @@ type runnerRef struct {
|
||||
loading bool // True only during initial load, then false forever
|
||||
gpus []ml.DeviceID // Recorded at time of provisioning
|
||||
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
|
||||
isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
|
||||
vramSize uint64
|
||||
totalSize uint64
|
||||
|
||||
@@ -676,6 +711,7 @@ type runnerRef struct {
|
||||
|
||||
model *Model
|
||||
modelPath string
|
||||
modelKey string
|
||||
numParallel int
|
||||
*api.Options
|
||||
}
|
||||
@@ -695,10 +731,16 @@ func (runner *runnerRef) unload() {
|
||||
}
|
||||
|
||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
||||
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested
|
||||
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
|
||||
if runner.isImagegen != wantImagegen {
|
||||
return true
|
||||
}
|
||||
|
||||
timeout := 10 * time.Second
|
||||
if runner.loading {
|
||||
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
|
||||
@@ -800,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
if runner == nil {
|
||||
return slog.StringValue("nil")
|
||||
}
|
||||
modelID := runner.modelPath
|
||||
if modelID == "" {
|
||||
modelID = runner.modelKey
|
||||
}
|
||||
attrs := []slog.Attr{}
|
||||
if runner.model != nil {
|
||||
attrs = append(attrs, slog.String("name", runner.model.Name))
|
||||
@@ -814,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
|
||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||
slog.Int("parallel", runner.numParallel),
|
||||
slog.Int("pid", runner.pid),
|
||||
slog.String("model", runner.modelPath),
|
||||
slog.String("model", modelID),
|
||||
)
|
||||
if runner.Options != nil {
|
||||
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
||||
@@ -859,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
|
||||
if d1 != d2 {
|
||||
return d1 < d2
|
||||
}
|
||||
// Secondary sort by model path lex order
|
||||
return a[i].modelPath < a[j].modelPath
|
||||
// Secondary sort by model key/path lex order
|
||||
n1 := a[i].modelPath
|
||||
if n1 == "" {
|
||||
n1 = a[i].modelKey
|
||||
}
|
||||
n2 := a[j].modelPath
|
||||
if n2 == "" {
|
||||
n2 = a[j].modelKey
|
||||
}
|
||||
return n1 < n2
|
||||
}
|
||||
|
||||
// TODO - future consideration to pick runners based on size
|
||||
@@ -920,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
|
||||
}
|
||||
|
||||
func (s *Scheduler) expireRunner(model *Model) {
|
||||
modelKey := schedulerModelKey(model)
|
||||
s.loadedMu.Lock()
|
||||
runner, ok := s.loaded[model.ModelPath]
|
||||
runner, ok := s.loaded[modelKey]
|
||||
s.loadedMu.Unlock()
|
||||
if ok {
|
||||
runner.refMu.Lock()
|
||||
|
||||
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
slog.Info("b")
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
require.Empty(t, successCh1b)
|
||||
require.Len(t, errCh1b, 1)
|
||||
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
|
||||
c.req.model.ModelPath = "bad path"
|
||||
slog.Info("c")
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
|
||||
// Starts in pending channel, then should be quickly processed to return an error
|
||||
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||
require.Empty(t, successCh1c)
|
||||
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
b.ctxDone()
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||
|
||||
require.Empty(t, successCh)
|
||||
require.Empty(t, errCh)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
}
|
||||
|
||||
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4
|
||||
|
||||
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||
loadedRunner := &runnerRef{
|
||||
model: loadedModel,
|
||||
modelKey: schedulerModelKey(loadedModel),
|
||||
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
Options: &opts,
|
||||
numParallel: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||
cancelReq()
|
||||
|
||||
select {
|
||||
case runner := <-successCh:
|
||||
require.Equal(t, loadedRunner, runner)
|
||||
default:
|
||||
t.Fatal("expected existing runner to be reused")
|
||||
}
|
||||
|
||||
require.Empty(t, errCh)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
}
|
||||
|
||||
func TestSchedExpireRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
|
||||
defer done()
|
||||
@@ -509,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = scenario1a.newServer
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
s.Run(ctx)
|
||||
select {
|
||||
|
||||
14
server/test_home_test.go
Normal file
14
server/test_home_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
190
version/update.go
Normal file
190
version/update.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
var updateCheckURLBase = "https://ollama.com"
|
||||
|
||||
// CheckForUpdate calls the ollama.com update API and reports whether a
|
||||
// newer version is available.
|
||||
func CheckForUpdate(ctx context.Context) (bool, error) {
|
||||
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse update URL: %w", err)
|
||||
}
|
||||
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", Version)
|
||||
requestURL.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
_ = auth.SignRequest(ctx, req)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("update check request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func cacheFilePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "update"), nil
|
||||
}
|
||||
|
||||
// CacheAvailableUpdate creates the update marker file.
|
||||
func CacheAvailableUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// HasCachedUpdate reports whether a non-stale update marker exists.
|
||||
func HasCachedUpdate() bool {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(fi.ModTime()) <= 24*time.Hour
|
||||
}
|
||||
|
||||
// ClearCachedUpdate removes the update marker file.
|
||||
func ClearCachedUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Remove(path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func IsOfficialInstall() bool {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
|
||||
case "darwin":
|
||||
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
|
||||
default:
|
||||
dir := filepath.Dir(exe)
|
||||
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
|
||||
}
|
||||
}
|
||||
|
||||
// DoUpdate downloads and runs the platform-appropriate install script.
|
||||
func DoUpdate(force bool) error {
|
||||
if !force && !IsOfficialInstall() {
|
||||
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
|
||||
}
|
||||
|
||||
var scriptURL, tmpPattern, shell string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
scriptURL = "https://ollama.com/install.ps1"
|
||||
tmpPattern = "ollama-install-*.ps1"
|
||||
shell = "powershell"
|
||||
default:
|
||||
scriptURL = "https://ollama.com/install.sh"
|
||||
tmpPattern = "ollama-install-*.sh"
|
||||
shell = "sh"
|
||||
}
|
||||
|
||||
resp, err := http.Get(scriptURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download install script: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download install script: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", tmpPattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("write install script: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cmd := exec.Command(shell, tmpFile.Name())
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// IsLocalHost reports whether the configured Ollama host points to the
|
||||
// local machine.
|
||||
func IsLocalHost(host *url.URL) bool {
|
||||
hostname := host.Hostname()
|
||||
switch hostname {
|
||||
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
146
version/update_test.go
Normal file
146
version/update_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
} else {
|
||||
t.Setenv("HOME", dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
t.Run("update available", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
|
||||
t.Error("missing expected query parameters")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Fatal("expected update to be available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("up to date", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if available {
|
||||
t.Fatal("expected no update available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("network error", func(t *testing.T) {
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = "http://localhost:1"
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := CheckForUpdate(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheRoundTrip(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
if !HasCachedUpdate() {
|
||||
t.Fatal("expected cached update to be present")
|
||||
}
|
||||
|
||||
if err := ClearCachedUpdate(); err != nil {
|
||||
t.Fatalf("cache clear: %v", err)
|
||||
}
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCachedUpdateStale(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
// Backdate the file to make it stale
|
||||
path := filepath.Join(tmp, ".ollama", "update")
|
||||
staleTime := time.Now().Add(-25 * time.Hour)
|
||||
os.Chtimes(path, staleTime, staleTime)
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update for stale file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
host string
|
||||
local bool
|
||||
}{
|
||||
{"http://127.0.0.1:11434", true},
|
||||
{"http://localhost:11434", true},
|
||||
{"http://[::1]:11434", true},
|
||||
{"http://0.0.0.0:11434", true},
|
||||
{"http://remote.example.com:11434", false},
|
||||
{"http://192.168.1.100:11434", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.host, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.host)
|
||||
if err != nil {
|
||||
t.Fatalf("parse URL: %v", err)
|
||||
}
|
||||
if got := IsLocalHost(u); got != tt.local {
|
||||
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
27
x/cmd/run.go
27
x/cmd/run.go
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -62,6 +64,18 @@ func isLocalServer() bool {
|
||||
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// truncateToolOutput truncates tool output to prevent context overflow.
|
||||
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
|
||||
func truncateToolOutput(output, modelName string) string {
|
||||
@@ -86,6 +100,10 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("cloud account endpoints are unavailable"))
|
||||
}
|
||||
|
||||
// Get signin URL from initial Whoami call
|
||||
_, err = client.Whoami(ctx)
|
||||
if err != nil {
|
||||
@@ -664,6 +682,15 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
supportsTools = false
|
||||
}
|
||||
|
||||
if enableWebsearch {
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if disabled, known := cloudStatusDisabled(cmd.Context(), client); known && disabled {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", internalcloud.DisabledError("web search is unavailable"))
|
||||
enableWebsearch = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool registry only if model supports tools
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
|
||||
@@ -30,6 +30,8 @@ type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
Parser string
|
||||
Renderer string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
@@ -37,7 +39,7 @@ type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: parserName,
|
||||
Renderer: rendererName,
|
||||
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
||||
}
|
||||
}
|
||||
|
||||
func resolveParserName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Parser != "" {
|
||||
return mf.Parser
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
|
||||
if mf != nil && mf.Renderer != "" {
|
||||
return mf.Renderer
|
||||
}
|
||||
|
||||
return inferred
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
Parser: "qwen3",
|
||||
Renderer: "qwen3",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
if config.Parser != "qwen3" {
|
||||
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
|
||||
}
|
||||
if config.Renderer != "qwen3" {
|
||||
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Errorf("Parser should be empty, got %q", config.Parser)
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Errorf("Renderer should be empty, got %q", config.Renderer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
if config.Parser != "" {
|
||||
t.Error("Parser should be empty")
|
||||
}
|
||||
if config.Renderer != "" {
|
||||
t.Error("Renderer should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
Parser: "qwen3-thinking",
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
if opts.Modelfile.Parser != "qwen3-thinking" {
|
||||
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
|
||||
}
|
||||
if opts.Modelfile.Renderer != "qwen3" {
|
||||
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveParserName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "empty parser uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "explicit parser overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Parser: "qwen3-thinking",
|
||||
},
|
||||
inferred: "qwen3",
|
||||
want: "qwen3-thinking",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRendererName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mf *ModelfileConfig
|
||||
inferred string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil modelfile uses inferred",
|
||||
mf: nil,
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "empty renderer uses inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "explicit renderer overrides inferred",
|
||||
mf: &ModelfileConfig{
|
||||
Renderer: "qwen3",
|
||||
},
|
||||
inferred: "qwen3-coder",
|
||||
want: "qwen3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
|
||||
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
|
||||
@@ -102,15 +102,20 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
||||
for _, entry := range entries {
|
||||
name := entry.name
|
||||
|
||||
// Try to get tensor by stripped name first, then with component prefix.
|
||||
// Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
|
||||
// while the tensors map uses stripped names (e.g., "model.layers.0.weight").
|
||||
// Try to get tensor by stripped name first, then with component prefix,
|
||||
// then fall back to "data" for legacy blobs created by older versions
|
||||
// that stored all tensors with the generic key "data".
|
||||
lookupName := name
|
||||
arr := sf.Get(lookupName)
|
||||
if arr == nil && mw.component != "" {
|
||||
lookupName = mw.component + "/" + name
|
||||
arr = sf.Get(lookupName)
|
||||
}
|
||||
if arr == nil {
|
||||
// Legacy blob format: tensor stored as "data"
|
||||
lookupName = "data"
|
||||
arr = sf.Get(lookupName)
|
||||
}
|
||||
if arr != nil {
|
||||
// Single-tensor blob or tensor found by name
|
||||
if dtype != 0 && arr.Dtype() != dtype {
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Private headers contain C++ implementation helpers and are not part of
|
||||
// the C API surface; parsing them can produce invalid wrapper signatures.
|
||||
if d.IsDir() && d.Name() == "private" {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
||||
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
||||
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
||||
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
||||
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
||||
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
||||
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_enable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
||||
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
||||
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
||||
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
|
||||
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
||||
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
|
||||
if (mlx_array_new_data_managed_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
|
||||
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
|
||||
if (mlx_array_set_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
||||
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
|
||||
if (mlx_cuda_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
|
||||
if (mlx_device_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
||||
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
|
||||
if (mlx_device_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
|
||||
if (mlx_device_count_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
|
||||
if (mlx_device_info_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
|
||||
if (mlx_device_info_get_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
|
||||
if (mlx_device_info_free_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
|
||||
if (mlx_device_info_has_key_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
|
||||
if (mlx_device_info_is_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
|
||||
if (mlx_device_info_get_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
|
||||
if (mlx_device_info_get_size_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
|
||||
if (mlx_device_info_get_keys_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
|
||||
if (mlx_distributed_all_gather_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
||||
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
|
||||
if (mlx_metal_device_info_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
|
||||
if (mlx_metal_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
||||
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
|
||||
return mlx_array_new_data_ptr(data, shape, dim, dtype);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
|
||||
}
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src) {
|
||||
return mlx_array_set_ptr(arr, src);
|
||||
}
|
||||
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
|
||||
return mlx_array_item_float64_ptr(res, arr);
|
||||
}
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
||||
return mlx_array_item_complex64_ptr(res, arr);
|
||||
}
|
||||
|
||||
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
|
||||
return mlx_array_data_float64_ptr(arr);
|
||||
}
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
|
||||
return mlx_array_data_complex64_ptr(arr);
|
||||
}
|
||||
|
||||
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
|
||||
return mlx_set_compile_mode_ptr(mode);
|
||||
}
|
||||
|
||||
int mlx_cuda_is_available(bool* res) {
|
||||
return mlx_cuda_is_available_ptr(res);
|
||||
}
|
||||
|
||||
mlx_device mlx_device_new(void) {
|
||||
return mlx_device_new_ptr();
|
||||
}
|
||||
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
|
||||
return mlx_set_default_device_ptr(dev);
|
||||
}
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
||||
return mlx_device_is_available_ptr(avail, dev);
|
||||
}
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type) {
|
||||
return mlx_device_count_ptr(count, type);
|
||||
}
|
||||
|
||||
mlx_device_info mlx_device_info_new(void) {
|
||||
return mlx_device_info_new_ptr();
|
||||
}
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
||||
return mlx_device_info_get_ptr(info, dev);
|
||||
}
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info) {
|
||||
return mlx_device_info_free_ptr(info);
|
||||
}
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_has_key_ptr(exists, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_is_string_ptr(is_string, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_string_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_size_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
||||
return mlx_device_info_get_keys_ptr(keys, info);
|
||||
}
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
||||
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
||||
}
|
||||
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
|
||||
return mlx_set_wired_limit_ptr(res, limit);
|
||||
}
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void) {
|
||||
return mlx_metal_device_info_ptr();
|
||||
}
|
||||
|
||||
int mlx_metal_is_available(bool* res) {
|
||||
return mlx_metal_is_available_ptr(res);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@
|
||||
#undef mlx_array_new_double
|
||||
#undef mlx_array_new_complex
|
||||
#undef mlx_array_new_data
|
||||
#undef mlx_array_new_data_managed
|
||||
#undef mlx_array_new_data_managed_payload
|
||||
#undef mlx_array_set
|
||||
#undef mlx_array_set_bool
|
||||
#undef mlx_array_set_int
|
||||
@@ -121,6 +123,7 @@
|
||||
#undef mlx_disable_compile
|
||||
#undef mlx_enable_compile
|
||||
#undef mlx_set_compile_mode
|
||||
#undef mlx_cuda_is_available
|
||||
#undef mlx_device_new
|
||||
#undef mlx_device_new_type
|
||||
#undef mlx_device_free
|
||||
@@ -131,6 +134,16 @@
|
||||
#undef mlx_device_get_type
|
||||
#undef mlx_get_default_device
|
||||
#undef mlx_set_default_device
|
||||
#undef mlx_device_is_available
|
||||
#undef mlx_device_count
|
||||
#undef mlx_device_info_new
|
||||
#undef mlx_device_info_get
|
||||
#undef mlx_device_info_free
|
||||
#undef mlx_device_info_has_key
|
||||
#undef mlx_device_info_is_string
|
||||
#undef mlx_device_info_get_string
|
||||
#undef mlx_device_info_get_size
|
||||
#undef mlx_device_info_get_keys
|
||||
#undef mlx_distributed_all_gather
|
||||
#undef mlx_distributed_all_max
|
||||
#undef mlx_distributed_all_min
|
||||
@@ -261,7 +274,6 @@
|
||||
#undef mlx_set_cache_limit
|
||||
#undef mlx_set_memory_limit
|
||||
#undef mlx_set_wired_limit
|
||||
#undef mlx_metal_device_info
|
||||
#undef mlx_metal_is_available
|
||||
#undef mlx_metal_start_capture
|
||||
#undef mlx_metal_stop_capture
|
||||
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_double_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
|
||||
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
|
||||
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
|
||||
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
|
||||
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
|
||||
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
|
||||
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
|
||||
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
|
||||
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
|
||||
#endif
|
||||
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
|
||||
extern int (*mlx_disable_compile_ptr)(void);
|
||||
extern int (*mlx_enable_compile_ptr)(void);
|
||||
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
|
||||
extern int (*mlx_cuda_is_available_ptr)(bool* res);
|
||||
extern mlx_device (*mlx_device_new_ptr)(void);
|
||||
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
|
||||
extern int (*mlx_device_free_ptr)(mlx_device dev);
|
||||
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
|
||||
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
|
||||
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
|
||||
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
|
||||
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
|
||||
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
|
||||
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
|
||||
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
|
||||
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
|
||||
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
|
||||
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
|
||||
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
|
||||
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
|
||||
extern int (*mlx_metal_is_available_ptr)(bool* res);
|
||||
extern int (*mlx_metal_start_capture_ptr)(const char* path);
|
||||
extern int (*mlx_metal_stop_capture_ptr)(void);
|
||||
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
|
||||
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
|
||||
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
|
||||
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
|
||||
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
|
||||
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void);
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
|
||||
@@ -2,76 +2,298 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
Port int
|
||||
*exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *Client) JoinPath(path string) string {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
|
||||
}).JoinPath(path).String()
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
func NewClient(modelName string) (*Client, error) {
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find a free port
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
if l, err := net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the current executable path
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
if runtime.GOOS == "linux" {
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
cmd: cmd,
|
||||
}
|
||||
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
slog.Info("mlx-runner", "msg", scanner.Text())
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("mlx-runner", "msg", line)
|
||||
c.lastErrLock.Lock()
|
||||
c.lastErr = line
|
||||
c.lastErrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||
}
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
c.done <- err
|
||||
}()
|
||||
|
||||
// Wait for subprocess to be ready
|
||||
if err := c.waitUntilRunning(); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) CheckError(w *http.Response) error {
|
||||
if w.StatusCode >= 400 {
|
||||
return errors.New(w.Status)
|
||||
func (c *Client) getLastErr() string {
|
||||
c.lastErrLock.Lock()
|
||||
defer c.lastErrLock.Unlock()
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
func (c *Client) waitUntilRunning() error {
|
||||
ctx := context.Background()
|
||||
timeout := time.After(2 * time.Minute)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-c.done:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
||||
}
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
||||
}
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
if err := c.Ping(ctx); err == nil {
|
||||
slog.Info("mlx runner is ready", "port", c.port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Options *completionOpts `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type completionOpts struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
MinP float32 `json:"min_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
|
||||
c.cmd.Process.Signal(os.Interrupt)
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
c.cmd.Process.Kill()
|
||||
}
|
||||
c.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements llm.LlamaServer.
|
||||
func (c *Client) Close() error {
|
||||
return c.Cmd.Process.Kill()
|
||||
}
|
||||
|
||||
// Completion implements llm.LlamaServer.
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(req); err != nil {
|
||||
return err
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
Temperature: req.Options.Temperature,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TopK: req.Options.TopK,
|
||||
NumPredict: req.Options.NumPredict,
|
||||
}
|
||||
}
|
||||
|
||||
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
if err := c.CheckError(w); err != nil {
|
||||
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
scanner := bufio.NewScanner(w.Body)
|
||||
for scanner.Scan() {
|
||||
bts := scanner.Bytes()
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var resp llm.CompletionResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(resp)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
return nil
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
if cresp.Done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
@@ -80,71 +302,89 @@ func (c *Client) ContextLength() int {
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
panic("unimplemented")
|
||||
return "", errors.New("not supported")
|
||||
}
|
||||
|
||||
// Embedding implements llm.LlamaServer.
|
||||
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
panic("unimplemented")
|
||||
return nil, 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// GetDeviceInfos implements llm.LlamaServer.
|
||||
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
panic("unimplemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPort implements llm.LlamaServer.
|
||||
func (c *Client) GetPort() int {
|
||||
return c.Port
|
||||
return c.port
|
||||
}
|
||||
|
||||
// HasExited implements llm.LlamaServer.
|
||||
func (c *Client) HasExited() bool {
|
||||
panic("unimplemented")
|
||||
select {
|
||||
case <-c.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Load implements llm.LlamaServer.
|
||||
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
return []ml.DeviceID{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ModelPath implements llm.LlamaServer.
|
||||
func (c *Client) ModelPath() string {
|
||||
panic("unimplemented")
|
||||
return c.modelName
|
||||
}
|
||||
|
||||
// Pid implements llm.LlamaServer.
|
||||
func (c *Client) Pid() int {
|
||||
panic("unimplemented")
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
return c.cmd.Process.Pid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
w, err := http.Get(c.JoinPath("/v1/status"))
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tokenize implements llm.LlamaServer.
|
||||
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokens []int
|
||||
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -153,22 +393,22 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
|
||||
func (c *Client) TotalSize() uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
|
||||
func (c *Client) VRAMSize() uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
|
||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
panic("unimplemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ llm.LlamaServer = (*Client)(nil)
|
||||
|
||||
10
x/mlxrunner/imports.go
Normal file
10
x/mlxrunner/imports.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
)
|
||||
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -133,6 +133,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
}
|
||||
|
||||
func (t *Array) Set(other *Array) {
|
||||
Free(t.desc.inputs...)
|
||||
other.desc.numRefs++
|
||||
t.desc.inputs = []*Array{other}
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
@@ -248,9 +249,9 @@ func Free(s ...*Array) (n int) {
|
||||
free := make([]*Array, 0, 8192)
|
||||
fn := func(t *Array) {
|
||||
if t.Valid() {
|
||||
free = append(free, t.desc.inputs...)
|
||||
t.desc.numRefs--
|
||||
if t.desc.numRefs <= 0 {
|
||||
free = append(free, t.desc.inputs...)
|
||||
logutil.Trace("Free", "t", t)
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
|
||||
@@ -24,6 +24,37 @@ func CheckInit() error {
|
||||
return initError
|
||||
}
|
||||
|
||||
// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadFromDir(dir string) bool {
|
||||
matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
|
||||
if err != nil || len(matches) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(dir, match)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library", "path", path)
|
||||
continue
|
||||
}
|
||||
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -33,44 +64,34 @@ func init() {
|
||||
return
|
||||
}
|
||||
|
||||
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
|
||||
if !ok {
|
||||
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
|
||||
return
|
||||
// Try OLLAMA_LIBRARY_PATH first
|
||||
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||
for _, dir := range filepath.SplitList(paths) {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, path := range filepath.SplitList(paths) {
|
||||
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
|
||||
if err != nil {
|
||||
initError = fmt.Errorf("failed to glob for MLX libraries in %s: %w", path, err)
|
||||
slog.Warn("MLX dynamic library not available", "error", initError)
|
||||
return
|
||||
// Build search paths: executable directory, then build directories
|
||||
var searchDirs []string
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
searchDirs = append(searchDirs, filepath.Dir(exe))
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(paths, match)
|
||||
slog.Info("Loading MLX dynamic library", "path", path)
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
|
||||
}
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library", "path", path)
|
||||
continue
|
||||
}
|
||||
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Loaded MLX dynamic library", "path", path)
|
||||
for _, dir := range searchDirs {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
initError = fmt.Errorf("failed to load any MLX dynamic library from OLLAMA_LIBRARY_PATH=%s", paths)
|
||||
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
|
||||
slog.Warn("MLX dynamic library not available", "error", initError)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
|
||||
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
|
||||
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
|
||||
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
|
||||
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
|
||||
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
|
||||
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
|
||||
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
|
||||
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
|
||||
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
|
||||
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
|
||||
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
|
||||
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
|
||||
const mlx_vector_array input) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
|
||||
const mlx_vector_array input_2) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
|
||||
size_t input_2_num) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_)(void) = NULL;
|
||||
int (*mlx_enable_compile_)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_)(mlx_device dev) = NULL;
|
||||
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
|
||||
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_)(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_)(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_)(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_)(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_rope_dynamic_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_scaled_dot_product_attention_)(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_load_reader_)(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_linalg_cholesky_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
|
||||
int (*mlx_metal_is_available_)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_)(void) = NULL;
|
||||
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
int axis,
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
|
||||
const mlx_array values,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array w,
|
||||
const mlx_array w_scales /* may be null */,
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array w,
|
||||
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_axis_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_max_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_segmented_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
|
||||
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
|
||||
const char * (*mlx_string_data_)(mlx_string str) = NULL;
|
||||
int (*mlx_string_free_)(mlx_string str) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
|
||||
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
|
||||
int (*mlx_custom_function_)(
|
||||
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
|
||||
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
||||
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
|
||||
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_new_double);
|
||||
CHECK_LOAD(handle, mlx_array_new_complex);
|
||||
CHECK_LOAD(handle, mlx_array_new_data);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
|
||||
CHECK_LOAD(handle, mlx_array_set);
|
||||
CHECK_LOAD(handle, mlx_array_set_bool);
|
||||
CHECK_LOAD(handle, mlx_array_set_int);
|
||||
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_disable_compile);
|
||||
CHECK_LOAD(handle, mlx_enable_compile);
|
||||
CHECK_LOAD(handle, mlx_set_compile_mode);
|
||||
CHECK_LOAD(handle, mlx_cuda_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_new);
|
||||
CHECK_LOAD(handle, mlx_device_new_type);
|
||||
CHECK_LOAD(handle, mlx_device_free);
|
||||
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_device_get_type);
|
||||
CHECK_LOAD(handle, mlx_get_default_device);
|
||||
CHECK_LOAD(handle, mlx_set_default_device);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_device_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_count);
|
||||
CHECK_LOAD(handle, mlx_device_info_new);
|
||||
CHECK_LOAD(handle, mlx_device_info_get);
|
||||
CHECK_LOAD(handle, mlx_device_info_free);
|
||||
CHECK_LOAD(handle, mlx_device_info_has_key);
|
||||
CHECK_LOAD(handle, mlx_device_info_is_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_size);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_keys);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_gather);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_max);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_min);
|
||||
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_distributed_recv_like);
|
||||
CHECK_LOAD(handle, mlx_distributed_send);
|
||||
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_set_error_handler);
|
||||
CHECK_LOAD(handle, _mlx_error);
|
||||
CHECK_LOAD(handle, mlx_export_function);
|
||||
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
|
||||
CHECK_LOAD(handle, mlx_fast_rms_norm);
|
||||
CHECK_LOAD(handle, mlx_fast_rope);
|
||||
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
|
||||
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
|
||||
CHECK_LOAD(handle, mlx_fft_fft);
|
||||
CHECK_LOAD(handle, mlx_fft_fft2);
|
||||
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fft_rfft);
|
||||
CHECK_LOAD(handle, mlx_fft_rfft2);
|
||||
CHECK_LOAD(handle, mlx_fft_rfftn);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
|
||||
CHECK_LOAD(handle, mlx_linalg_cross);
|
||||
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_set_cache_limit);
|
||||
CHECK_LOAD(handle, mlx_set_memory_limit);
|
||||
CHECK_LOAD(handle, mlx_set_wired_limit);
|
||||
CHECK_LOAD(handle, mlx_metal_device_info);
|
||||
CHECK_LOAD(handle, mlx_metal_is_available);
|
||||
CHECK_LOAD(handle, mlx_metal_start_capture);
|
||||
CHECK_LOAD(handle, mlx_metal_stop_capture);
|
||||
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_full);
|
||||
CHECK_LOAD(handle, mlx_full_like);
|
||||
CHECK_LOAD(handle, mlx_gather);
|
||||
CHECK_LOAD(handle, mlx_gather_single);
|
||||
CHECK_LOAD(handle, mlx_gather_mm);
|
||||
CHECK_LOAD(handle, mlx_gather_qmm);
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_prod_axis);
|
||||
CHECK_LOAD(handle, mlx_prod);
|
||||
CHECK_LOAD(handle, mlx_put_along_axis);
|
||||
CHECK_LOAD(handle, mlx_qqmm);
|
||||
CHECK_LOAD(handle, mlx_quantize);
|
||||
CHECK_LOAD(handle, mlx_quantized_matmul);
|
||||
CHECK_LOAD(handle, mlx_radians);
|
||||
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_round);
|
||||
CHECK_LOAD(handle, mlx_rsqrt);
|
||||
CHECK_LOAD(handle, mlx_scatter);
|
||||
CHECK_LOAD(handle, mlx_scatter_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_axis);
|
||||
CHECK_LOAD(handle, mlx_scatter_max);
|
||||
CHECK_LOAD(handle, mlx_scatter_max_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_min);
|
||||
CHECK_LOAD(handle, mlx_scatter_min_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod_single);
|
||||
CHECK_LOAD(handle, mlx_segmented_mm);
|
||||
CHECK_LOAD(handle, mlx_sigmoid);
|
||||
CHECK_LOAD(handle, mlx_sign);
|
||||
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_string_set);
|
||||
CHECK_LOAD(handle, mlx_string_data);
|
||||
CHECK_LOAD(handle, mlx_string_free);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_async_eval);
|
||||
CHECK_LOAD(handle, mlx_checkpoint);
|
||||
CHECK_LOAD(handle, mlx_custom_function);
|
||||
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_jvp);
|
||||
CHECK_LOAD(handle, mlx_value_and_grad);
|
||||
CHECK_LOAD(handle, mlx_vjp);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_vector_array_new);
|
||||
CHECK_LOAD(handle, mlx_vector_array_set);
|
||||
CHECK_LOAD(handle, mlx_vector_array_free);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,10 @@
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
|
||||
@@ -306,19 +306,42 @@ func AddMM(c, a, b *Array, alpha, beta float32) *Array {
|
||||
|
||||
// Scalar helpers
|
||||
|
||||
// scalarWithDtype creates a scalar array matching the dtype of a.
|
||||
// Matching dtype is important for graph fusion and avoiding implicit casts.
|
||||
func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
||||
f32 := C.mlx_array_new_float(C.float(s))
|
||||
dtype := a.DType()
|
||||
if dtype == DTypeFloat32 {
|
||||
return f32
|
||||
}
|
||||
casted := C.mlx_array_new()
|
||||
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
C.mlx_array_free(f32)
|
||||
return casted
|
||||
}
|
||||
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return a.Add(scalar)
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("ADD_SCALAR", a)
|
||||
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return a.Multiply(scalar)
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("MUL_SCALAR", a)
|
||||
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func DivScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return a.Divide(scalar)
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("DIV_SCALAR", a)
|
||||
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func FloorDivideScalar(a *Array, s int32) *Array {
|
||||
|
||||
85
x/mlxrunner/model/base/base.go
Normal file
85
x/mlxrunner/model/base/base.go
Normal file
@@ -0,0 +1,85 @@
|
||||
//go:build mlx
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
type Model interface {
|
||||
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
// stacking, quantized layer creation) happens here.
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
registry = make(map[string]func(root *model.Root) (Model, error))
|
||||
)
|
||||
|
||||
// Register registers a model constructor by architecture name.
|
||||
// Called from init() in model packages. Panics on duplicate registration.
|
||||
func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := registry[arch]; exists {
|
||||
panic(fmt.Sprintf("model architecture %q already registered", arch))
|
||||
}
|
||||
registry[arch] = fn
|
||||
}
|
||||
|
||||
// New reads config.json from the manifest, detects the architecture, looks up
|
||||
// the registered constructor, and calls it to create the model (with config
|
||||
// parsed and struct created, but weights not yet loaded).
|
||||
func New(root *model.Root) (Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
if len(archConfig.Architectures) == 0 {
|
||||
return nil, fmt.Errorf("no architectures found in config.json")
|
||||
}
|
||||
|
||||
arch := archConfig.Architectures[0]
|
||||
slog.Info("Model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := registry[arch]
|
||||
mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root)
|
||||
}
|
||||
|
||||
// Weights returns the model's LoadWeights method, which encapsulates all
|
||||
// weight assignment and post-processing (MLA absorption, expert stacking).
|
||||
func Weights(m Model) func(map[string]*mlx.Array) error {
|
||||
return m.LoadWeights
|
||||
}
|
||||
3
x/mlxrunner/model/base/base_stub.go
Normal file
3
x/mlxrunner/model/base/base_stub.go
Normal file
@@ -0,0 +1,3 @@
|
||||
//go:build !mlx
|
||||
|
||||
package base
|
||||
92
x/mlxrunner/model/linear.go
Normal file
92
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||
type LinearFactory struct {
|
||||
tensors map[string]*mlx.Array
|
||||
defaultGroupSize int
|
||||
defaultBits int
|
||||
defaultMode string
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||
func NewLinearFactory(
|
||||
tensors map[string]*mlx.Array,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) LinearFactory {
|
||||
return LinearFactory{
|
||||
tensors: tensors,
|
||||
defaultGroupSize: defaultGroupSize,
|
||||
defaultBits: defaultBits,
|
||||
defaultMode: defaultMode,
|
||||
tensorQuant: tensorQuant,
|
||||
}
|
||||
}
|
||||
|
||||
// Make constructs a linear layer at path.
|
||||
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||
return MakeLinearLayer(
|
||||
f.tensors,
|
||||
path,
|
||||
f.defaultGroupSize,
|
||||
f.defaultBits,
|
||||
f.defaultMode,
|
||||
f.tensorQuant,
|
||||
)
|
||||
}
|
||||
|
||||
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||
func MakeLinearLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
130
x/mlxrunner/model/quant.go
Normal file
130
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,130 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 32, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8", "":
|
||||
return 64, 8, "affine"
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||
// when available, otherwise falling back to the provided model defaults.
|
||||
func TensorQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||
if tensorQuant != nil {
|
||||
if tq := tensorQuant[tensorName]; tq != nil {
|
||||
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||
if tq.GroupSize > 0 {
|
||||
groupSize = tq.GroupSize
|
||||
}
|
||||
return groupSize, bits, mode, true
|
||||
}
|
||||
}
|
||||
return defaultGroupSize, defaultBits, defaultMode, false
|
||||
}
|
||||
|
||||
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||
// inference for affine packed tensors.
|
||||
func ResolveLinearQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
weight, scales *mlx.Array,
|
||||
) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
tensorName,
|
||||
)
|
||||
|
||||
if mode == "affine" {
|
||||
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||
groupSize = inferredGroupSize
|
||||
bits = inferredBits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||
// tensors from packed weight and scale shapes.
|
||||
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||
if weight == nil || scales == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scales.Dims()
|
||||
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightCols := weightShape[len(weightShape)-1]
|
||||
scalesCols := scaleShape[len(scaleShape)-1]
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return 32, 4, true
|
||||
case groupSize8 == 64:
|
||||
return 64, 8, true
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
if hintBits == 8 {
|
||||
return 32, 8, true
|
||||
}
|
||||
if hintBits == 4 {
|
||||
return 64, 4, true
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return groupSize4, 4, true
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return groupSize8, 8, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func isCommonGroupSize(v int) bool {
|
||||
switch v {
|
||||
case 16, 32, 64, 128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user