Compare commits

..

11 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| jmorganca | ebbebdf3b1 | better logging / fix defaults | 2025-12-20 21:05:34 -08:00 |
| jmorganca | 49393385ca | better logging | 2025-12-20 19:52:05 -08:00 |
| jmorganca | 12ff2d1461 | tests | 2025-12-20 19:23:09 -08:00 |
| jmorganca | f90d968b8b | skip | 2025-12-20 19:10:35 -08:00 |
| jmorganca | c623b256a3 | better logs | 2025-12-20 18:13:00 -08:00 |
| jmorganca | 8c8fb2f9f0 | tweak | 2025-12-20 18:10:24 -08:00 |
| jmorganca | 6e00a0c89a | speed tracker | 2025-12-20 18:02:52 -08:00 |
| jmorganca | 55b1ee2557 | wip | 2025-12-20 18:01:07 -08:00 |
| jmorganca | 51cb1155ba | streaming | 2025-12-20 17:39:55 -08:00 |
| jmorganca | 7c5b656bb3 | wip | 2025-12-20 17:24:03 -08:00 |
| jmorganca | bddb27ab5b | client2 updated | 2025-12-20 16:16:34 -08:00 |
33 changed files with 921 additions and 2932 deletions

View File

@@ -1,779 +0,0 @@
// Package anthropic provides core transformation logic for compatibility with the Anthropic Messages API
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks
Text string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
// Handle system prompt
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
// Convert messages
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
// Build options
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
// Convert tools
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
// Handle thinking
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
// Handle array of content blocks
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
}
// URL images would need to be fetched - skip for now
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = api.ToolCallFunctionArguments(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
// Extract text from content blocks
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
// Build the main message
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
// Add thinking block if present
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: r.Message.Thinking,
})
}
// Add text content if present
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: r.Message.Content,
})
}
// Add tool use blocks
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
// Map stop reason
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
// NewStreamConverter creates a new StreamConverter
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
// First write: emit message_start
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
// Handle thinking content
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: "",
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
// Handle text content
if r.Message.Content != "" {
// Close thinking block if it was open
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: "",
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
// Handle tool calls
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
// Close any previous block
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
// Start tool use block
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
// Send input as JSON delta
argsJSON, _ := json.Marshal(tc.Function.Arguments)
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
// Close tool use block
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
// Handle done
if r.Done {
// Close any open block
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}

View File

@@ -1,667 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "Paris"},
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %q", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "Paris"},
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}

View File

@@ -14,7 +14,6 @@
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
### Resources

View File

@@ -1,339 +0,0 @@
---
title: Anthropic compatibility
---
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages), allowing existing applications and tools such as Claude Code to connect to Ollama.
## Usage
### Environment variables
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
<CodeGroup dropdown>
```python basic.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama', # required but ignored
)
message = client.messages.create(
model='llama3.2:3b',
max_tokens=1024,
messages=[
{'role': 'user', 'content': 'Hello, how are you?'}
]
)
print(message.content[0].text)
```
```javascript basic.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama", // required but ignored
});
const message = await anthropic.messages.create({
model: "llama3.2:3b",
max_tokens: 1024,
messages: [{ role: "user", content: "Hello, how are you?" }],
});
console.log(message.content[0].text);
```
```shell basic.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: ollama" \
-H "anthropic-version: 2023-06-01" \
-d '{
"model": "llama3.2:3b",
"max_tokens": 1024,
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
}'
```
</CodeGroup>
### Streaming example
<CodeGroup dropdown>
```python streaming.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
with client.messages.stream(
model='llama3.2:3b',
max_tokens=1024,
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
) as stream:
for text in stream.text_stream:
print(text, end='', flush=True)
```
```javascript streaming.js
import Anthropic from "@anthropic-ai/sdk";
const anthropic = new Anthropic({
baseURL: "http://localhost:11434",
apiKey: "ollama",
});
const stream = await anthropic.messages.stream({
model: "llama3.2:3b",
max_tokens: 1024,
messages: [{ role: "user", content: "Count from 1 to 10" }],
});
for await (const event of stream) {
if (
event.type === "content_block_delta" &&
event.delta.type === "text_delta"
) {
process.stdout.write(event.delta.text);
}
}
```
```shell streaming.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "llama3.2:3b",
"max_tokens": 1024,
"stream": true,
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
}'
```
</CodeGroup>
### Tool calling example
<CodeGroup dropdown>
```python tools.py
import anthropic
client = anthropic.Anthropic(
base_url='http://localhost:11434',
api_key='ollama',
)
message = client.messages.create(
model='llama3.2:3b',
max_tokens=1024,
tools=[
{
'name': 'get_weather',
'description': 'Get the current weather in a location',
'input_schema': {
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'The city and state, e.g. San Francisco, CA'
}
},
'required': ['location']
}
}
],
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
)
for block in message.content:
if block.type == 'tool_use':
print(f'Tool: {block.name}')
print(f'Input: {block.input}')
```
```shell tools.sh
curl -X POST http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "llama3.2:3b",
"max_tokens": 1024,
"tools": [
{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}
],
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
}'
```
</CodeGroup>
## Using with Claude Code
[Claude Code](https://docs.anthropic.com/en/docs/claude-code) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model llama3.2:3b
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
Then run Claude Code with any Ollama model:
```shell
claude --model llama3.2:3b
claude --model qwen3:8b
claude --model deepseek-r1:14b
```
## Endpoints
### `/v1/messages`
#### Supported features
- [x] Messages
- [x] Streaming
- [x] System prompts
- [x] Multi-turn conversations
- [x] Vision (images)
- [x] Tools (function calling)
- [x] Tool results
- [x] Thinking/extended thinking
#### Supported request fields
- [x] `model`
- [x] `max_tokens`
- [x] `messages`
- [x] Text `content`
- [x] Image `content` (base64)
- [x] Array of content blocks
- [x] `tool_use` blocks
- [x] `tool_result` blocks
- [x] `thinking` blocks
- [x] `system` (string or array)
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `top_k`
- [x] `stop_sequences`
- [x] `tools`
- [x] `thinking`
- [ ] `tool_choice`
- [ ] `metadata`
#### Supported response fields
- [x] `id`
- [x] `type`
- [x] `role`
- [x] `model`
- [x] `content` (text, tool_use, thinking blocks)
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
- [x] `usage` (input_tokens, output_tokens)
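A non-streaming reply combines these fields into a single message object. The following is a sketch of a typical response body; the id, text, and token counts are illustrative only:
```shell
# Illustrative non-streaming response body (id, text, and token counts will differ):
{
  "id": "msg_...",
  "type": "message",
  "role": "assistant",
  "model": "llama3.2:3b",
  "content": [{ "type": "text", "text": "Hello! How can I help you today?" }],
  "stop_reason": "end_turn",
  "usage": { "input_tokens": 10, "output_tokens": 9 }
}
```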
#### Streaming events
- [x] `message_start`
- [x] `content_block_start`
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
- [x] `content_block_stop`
- [x] `message_delta`
- [x] `message_stop`
- [x] `ping`
- [x] `error`
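Put together, a short streamed text-only reply (such as the streaming example above) arrives roughly in the order below. This is a sketch with abbreviated payloads; the id, text, and token counts are illustrative:
```shell
# Illustrative SSE event order for a short streamed text reply:

event: message_start
data: {"type":"message_start","message":{"id":"msg_...","type":"message","role":"assistant","model":"llama3.2:3b","content":[],"usage":{"input_tokens":12,"output_tokens":0}}}

event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}

event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}

event: content_block_stop
data: {"type":"content_block_stop","index":0}

event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":5}}

event: message_stop
data: {"type":"message_stop"}
```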
## Models
Before using a model, pull it locally with `ollama pull`:
```shell
ollama pull llama3.2:3b
```
### Default model names
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model to that name:
```shell
ollama cp llama3.2:3b claude-3-5-sonnet
```
Afterwards, this new model name can be specified in the `model` field:
```shell
curl http://localhost:11434/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "claude-3-5-sonnet",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Differences from the Anthropic API
### Behavior differences
- API key is accepted but not validated
- `anthropic-version` header is accepted but not used
- Token counts are approximations based on the underlying model's tokenizer
### Not supported
The following Anthropic API features are not currently supported:
| Feature | Description |
|---------|-------------|
| `/v1/messages/count_tokens` | Token counting endpoint |
| `tool_choice` | Forcing specific tool use or disabling tools |
| `metadata` | Request metadata (user_id) |
| Prompt caching | `cache_control` blocks for caching prefixes |
| Batches API | `/v1/messages/batches` for async batch processing |
| Citations | `citations` content blocks |
| PDF support | `document` content blocks with PDF files |
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
### Partial support
| Feature | Status |
|---------|--------|
| Image content | Base64 images supported; URL images not supported |
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
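As a rough illustration of the extended thinking support above, a request such as the following enables thinking; the `budget_tokens` value is accepted but not enforced (the model and values shown are examples only):
```shell
curl http://localhost:11434/v1/messages \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3:8b",
    "max_tokens": 1024,
    "thinking": { "type": "enabled", "budget_tokens": 1024 },
    "messages": [{ "role": "user", "content": "What is 17 * 24?" }]
  }'
```
For models that support it, the response may contain a `thinking` content block followed by a `text` block.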

View File

@@ -139,8 +139,7 @@
"/api/streaming",
"/api/usage",
"/api/errors",
"/api/openai-compatibility",
"/api/anthropic-compatibility"
"/api/openai-compatibility"
]
},
{

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu).
Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library:
> **NOTE:**
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> **NOTE:**
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server)
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
Sometimes Ollama can have difficulty initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem.

View File

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 ++++++++++
9 files changed, 1005 insertions(+), 17 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 976 insertions(+), 17 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6852d2e20..334a30135 100644
index 6852d2e20..48cdb1dcf 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -109,7 +109,7 @@ index 6852d2e20..334a30135 100644
+
+#if defined(GGML_USE_HIP)
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c58..dba8f4695 100644
index fe57d4c58..1c07e767a 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
@@ -216,7 +216,7 @@ index fe57d4c58..dba8f4695 100644
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
+GGML_API void ggml_nvml_release();
+GGML_API int ggml_hip_mgmt_init();
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
+GGML_API void ggml_hip_mgmt_release();
+
#ifdef __cplusplus
@@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 5349bce24..0103fd03a 100644
index 5349bce24..d43d46d1d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -236,6 +236,7 @@ class vk_memory_logger;
@@ -334,7 +334,7 @@ index 5349bce24..0103fd03a 100644
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -505,10 +505,10 @@ index 5349bce24..0103fd03a 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..23c765806
index 000000000..c1949b899
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,558 @@
@@ -0,0 +1,529 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
@@ -842,7 +842,7 @@ index 000000000..23c765806
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
+ if (gpu != NULL) gpu->pVtbl->Release(gpu)
+
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock);
+ if (adlx.handle == NULL) {
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -966,16 +966,13 @@ index 000000000..23c765806
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
@@ -1009,6 +1006,7 @@ index 000000000..23c765806
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
@@ -1021,33 +1019,6 @@ index 000000000..23c765806
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+
+ if (is_integrated_gpu) {
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gtt;
+ totalFileStream >> gtt;
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gttUsed;
+ usedFileStream >> gttUsed;
+ memory += gtt;
+ memoryUsed += gttUsed;
+ }
+
+ *total = memory;
+ *free = memory - memoryUsed;
+
+ file.close();

View File

@@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index dba8f4695..7e17032c7 100644
index 1c07e767a..0da3e065b 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release();
+GGML_API int ggml_dxgi_pdh_init();
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
@@ -38,7 +38,7 @@ index dba8f4695..7e17032c7 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0103fd03a..9cc4ebdef 100644
index d43d46d1d..df79f9f79 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -10,7 +10,7 @@ fallback to cpu
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 334a30135..5c9dfd032 100644
index 48cdb1dcf..3102d7ea7 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g

View File

@@ -1,152 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
type AnthropicWriter struct {
BaseWriter
stream bool
id string
model string
converter *anthropic.StreamConverter
}
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(serr.StatusCode, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *AnthropicWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
// Validate required fields
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.MaxTokens <= 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
return
}
// Convert to internal format
chatReq, err := anthropic.FromMessagesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
messageID := anthropic.GenerateMessageID()
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
// Set headers based on streaming mode
if req.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -1,487 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
_ = json.Unmarshal(bodyBytes, capturedRequest)
c.Next()
}
}
func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err anthropic.ErrorResponse
}
var capturedRequest *api.ChatRequest
stream := true
testCases := []testCase{
{
name: "basic message",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with system prompt",
body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with options",
body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{
"num_predict": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop": []string{"\n", "END"},
},
Stream: &False,
},
},
{
name: "streaming",
body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &stream,
},
},
{
name: "with tools",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with tool result",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"location": "Paris"},
},
},
},
},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
},
},
{
name: "with thinking enabled",
body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
Options: map[string]any{"num_predict": 1024},
Stream: &False,
Think: &api.ThinkValue{Value: true},
},
},
{
name: "missing model error",
body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "model is required",
},
},
},
{
name: "missing max_tokens error",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "max_tokens is required and must be positive",
},
},
},
{
name: "missing messages error",
body: `{
"model": "test-model",
"max_tokens": 1024
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "messages is required",
},
},
},
{
name: "tool_use missing id error",
body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
err: anthropic.ErrorResponse{
Type: "error",
Error: anthropic.Error{
Type: "invalid_request_error",
Message: "tool_use block missing required 'id' field",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
router.Handle(http.MethodPost, "/v1/messages", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Type != "" {
// Expect error
if resp.Code == http.StatusOK {
t.Fatalf("expected error response, got 200 OK")
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != tc.err.Type {
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
}
if errResp.Error.Type != tc.err.Error.Type {
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
}
if errResp.Error.Message != tc.err.Error.Message {
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
}
if capturedRequest == nil {
t.Fatal("request was not captured")
}
// Compare relevant fields
if capturedRequest.Model != tc.req.Model {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
}
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff)
}
if tc.req.Stream != nil && capturedRequest.Stream != nil {
if *tc.req.Stream != *capturedRequest.Stream {
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
}
}
if tc.req.Think != nil {
if capturedRequest.Think == nil {
t.Error("expected Think to be set")
} else if capturedRequest.Think.Value != tc.req.Think.Value {
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
}
}
})
}
}
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("streaming sets correct headers", func(t *testing.T) {
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Check headers were set
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
}
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
}
c.Status(http.StatusOK)
})
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
})
}
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", resp.Code)
}
var errResp anthropic.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatalf("failed to unmarshal error: %v", err)
}
if errResp.Type != "error" {
t.Errorf("expected type 'error', got %q", errResp.Type)
}
if errResp.Error.Type != "invalid_request_error" {
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
}
}
func TestAnthropicWriter_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
// Simulate Ollama response
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
data, _ := json.Marshal(resp)
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.Write(data)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.Code)
}
var result anthropic.MessagesResponse
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Text != "Hello there!" {
t.Errorf("expected text 'Hello there!', got %q", result.Content[0].Text)
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 {
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
}
if result.Usage.OutputTokens != 5 {
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
}
}

View File

@@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
#if defined(GGML_USE_HIP)
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release();
GGML_API int ggml_dxgi_pdh_init();
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);

View File

@@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
switch (props2.properties.vendorID) {
case VK_VENDOR_ID_AMD:
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
if (gpus != NULL) gpus->pVtbl->Release(gpus); \
if (gpu != NULL) gpu->pVtbl->Release(gpu)
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
std::lock_guard<std::mutex> lock(ggml_adlx_lock);
if (adlx.handle == NULL) {
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -455,16 +455,13 @@ int ggml_hip_mgmt_init() {
return 0;
}
void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
@@ -498,6 +495,7 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memory;
totalFileStream >> memory;
*total = memory;
std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
@@ -510,33 +508,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memoryUsed;
usedFileStream >> memoryUsed;
if (is_integrated_gpu) {
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gtt;
totalFileStream >> gtt;
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gttUsed;
usedFileStream >> gttUsed;
memory += gtt;
memoryUsed += gttUsed;
}
*total = memory;
*free = memory - memoryUsed;
file.close();

View File

@@ -2,9 +2,11 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
@@ -31,9 +33,45 @@ const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errPartSlow = errors.New("part slow, racing")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
// speedTracker tracks download speeds and computes rolling median.
type speedTracker struct {
mu sync.Mutex
speeds []float64 // bytes per second
}
func (s *speedTracker) Record(bytesPerSec float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.speeds = append(s.speeds, bytesPerSec)
// Keep last 100 samples
if len(s.speeds) > 100 {
s.speeds = s.speeds[1:]
}
}
func (s *speedTracker) Median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 3 {
return 0 // not enough data
}
// Simple median: sort a copy and take middle
sorted := make([]float64, len(s.speeds))
copy(sorted, s.speeds)
for i := range sorted {
for j := i + 1; j < len(sorted); j++ {
if sorted[j] < sorted[i] {
sorted[i], sorted[j] = sorted[j], sorted[i]
}
}
}
return sorted[len(sorted)/2]
}
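// The download loop below uses this median as a relative-slowness signal: once a part
// has been running for a few seconds, a speed under 10% of the rolling median flags it
// as slow so the request can be killed and retried. Roughly:
//
//    if median := tracker.Median(); median > 0 && currentSpeed < median*0.1 {
//        return errPartSlow
//    }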
var blobDownloadManager sync.Map
type blobDownload struct {
@@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil
}
const (
numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
var (
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
)
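// Both knobs are read from the environment of the process performing the download;
// for example (hypothetical values) OLLAMA_DOWNLOAD_PART_SIZE=128 and
// OLLAMA_DOWNLOAD_CONCURRENCY=16 would use 128MB parts with 16 concurrent part
// downloads. The defaults are 64MB parts and 48 concurrent downloads.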
func envInt(key string, defaultVal int) int {
if s := os.Getenv(key); s != "" {
if v, err := strconv.Atoi(s); err == nil {
return v
}
}
return defaultVal
}
// streamHasher reads a file sequentially and hashes it as chunks complete.
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
// It works by reading from the OS page cache: data that was just written is typically still in RAM.
type streamHasher struct {
file *os.File
hasher hash.Hash
parts []*blobDownloadPart
total int64 // total bytes to hash
hashed atomic.Int64
mu sync.Mutex
cond *sync.Cond
completed []bool
done bool
err error
}
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
h := &streamHasher{
file: file,
hasher: sha256.New(),
parts: parts,
total: total,
completed: make([]bool, len(parts)),
}
h.cond = sync.NewCond(&h.mu)
return h
}
// MarkComplete signals that a part has been written to disk.
func (h *streamHasher) MarkComplete(partIndex int) {
h.mu.Lock()
h.completed[partIndex] = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Run reads and hashes the file sequentially. Call in a goroutine.
func (h *streamHasher) Run() {
buf := make([]byte, 64*1024) // 64KB read buffer
var offset int64
for i, part := range h.parts {
// Wait for this part to be written
h.mu.Lock()
for !h.completed[i] && !h.done {
h.cond.Wait()
}
if h.done {
h.mu.Unlock()
return
}
h.mu.Unlock()
// Read and hash this part (from page cache)
remaining := part.Size
for remaining > 0 {
n := int64(len(buf))
if n > remaining {
n = remaining
}
nr, err := h.file.ReadAt(buf[:n], offset)
if err != nil && err != io.EOF {
h.mu.Lock()
h.err = err
h.mu.Unlock()
return
}
h.hasher.Write(buf[:nr])
offset += int64(nr)
remaining -= int64(nr)
h.hashed.Store(offset)
}
}
}
// Stop signals the hasher to exit early.
func (h *streamHasher) Stop() {
h.mu.Lock()
h.done = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Hashed returns bytes hashed so far.
func (h *streamHasher) Hashed() int64 {
return h.hashed.Load()
}
// Digest returns the computed hash.
func (h *streamHasher) Digest() string {
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
}
// Err returns any error from hashing.
func (h *streamHasher) Err() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.err
}
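// A minimal usage sketch, mirroring how run() wires this up below: start Run in a
// goroutine, mark each part complete as its download finishes, then wait and compare
// the digest.
//
//    sh := newStreamHasher(file, parts, total)
//    done := make(chan struct{})
//    go func() { sh.Run(); close(done) }()
//    // ... downloads call sh.MarkComplete(part.N) as each part is written ...
//    <-done
//    if sh.Err() == nil && sh.Digest() != expected {
//        // digest mismatch
//    }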
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
}, "-")
}
func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load()
}
func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
@@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
size := b.Total / numDownloadParts
switch {
case size < minDownloadPartSize:
size = minDownloadPartSize
case size > maxDownloadPartSize:
size = maxDownloadPartSize
}
size := downloadPartSize
var offset int64
for offset < b.Total {
if offset+size > b.Total {
@@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
setSparse(file)
_ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks to disk, hash by reading from page cache.
// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
// The hasher follows behind the downloaders, reading recently-written
// data from OS page cache (RAM) rather than disk.
sh := newStreamHasher(file, b.Parts, b.Total)
tracker := &speedTracker{}
// Start hasher goroutine
hashDone := make(chan struct{})
go func() {
sh.Run()
close(hashDone)
}()
// Log progress periodically
// Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM
const pageCacheWarningBytes = 4 << 30 // 4GB
progressDone := make(chan struct{})
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
downloaded := b.Completed.Load()
hashed := sh.Hashed()
dlPct := int(downloaded * 100 / b.Total)
hPct := int(hashed * 100 / b.Total)
spread := dlPct - hPct
spreadBytes := downloaded - hashed
slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
if spreadBytes > pageCacheWarningBytes {
slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
}
case <-progressDone:
return
}
}
}()
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
sh.MarkComplete(part.N)
continue
}
g.Go(func() error {
var err error
var slowRetries int
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part)
// After 3 slow retries, stop checking slowness and let it complete
skipSlowCheck := slowRetries >= 3
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case errors.Is(err, errPartSlow):
// Kill slow request, retry immediately (stays within concurrency limit)
slowRetries++
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
default:
sh.MarkComplete(part.N)
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
if err := g.Wait(); err != nil {
close(progressDone)
sh.Stop()
return err
}
// Wait for hasher to finish
<-hashDone
close(progressDone)
if err := sh.Err(); err != nil {
return err
}
// Verify hash
if computed := sh.Digest(); computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
@@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
// downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only).
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
g, ctx := errgroup.WithContext(ctx)
startTime := time.Now()
var bytesAtLastCheck atomic.Int64
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
w := io.NewOffsetWriter(file, part.Offset)
buf := make([]byte, 32*1024)
var written int64
for written < part.Size {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
written += int64(n)
b.Completed.Add(int64(n))
bytesAtLastCheck.Store(written)
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Now()
part.lastUpdatedMu.Unlock()
}
if err == io.EOF {
break
}
if err != nil {
b.Completed.Add(-written)
return err
}
}
part.Completed.Add(n)
if err := b.writePart(part.Name(), part); err != nil {
return err
// Record speed for this part
elapsed := time.Since(startTime).Seconds()
if elapsed > 0 {
tracker.Record(float64(part.Size) / elapsed)
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
part.Completed.Store(part.Size)
return b.writePart(part.Name(), part)
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
var lastBytes int64
checksWithoutProgress := 0
for {
select {
case <-ticker.C:
@@ -365,19 +587,47 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
return nil
}
currentBytes := bytesAtLastCheck.Load()
// Check for complete stall (30 seconds no progress)
part.lastUpdatedMu.Lock()
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
return errPartStalled
}
// Check for slow speed after 5+ seconds (only for multi-part downloads)
// Skip if we've already retried for slowness too many times
elapsed := time.Since(startTime).Seconds()
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
currentSpeed := float64(currentBytes) / elapsed
median := tracker.Median()
// If we're below 10% of median speed, flag as slow
if median > 0 && currentSpeed < median*0.1 {
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
return errPartSlow
}
}
// Also treat ten consecutive one-second checks with no progress as a stall
if currentBytes == lastBytes {
checksWithoutProgress++
if checksWithoutProgress >= 10 {
slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
return errPartStalled
}
} else {
checksWithoutProgress = 0
}
lastBytes = currentBytes
case <-ctx.Done():
return ctx.Err()
}

319
server/download_test.go Normal file
View File

@@ -0,0 +1,319 @@
package server
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"os"
"sync"
"testing"
)
func TestSpeedTracker_Median(t *testing.T) {
s := &speedTracker{}
// Less than 3 samples returns 0
s.Record(100)
s.Record(200)
if got := s.Median(); got != 0 {
t.Errorf("expected 0 with < 3 samples, got %f", got)
}
// With 3+ samples, returns median
s.Record(300)
// Samples: [100, 200, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
// Add more samples
s.Record(50)
s.Record(250)
// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
}
func TestSpeedTracker_RollingWindow(t *testing.T) {
s := &speedTracker{}
// Add 105 samples (should keep only last 100)
for i := 0; i < 105; i++ {
s.Record(float64(i))
}
s.mu.Lock()
if len(s.speeds) != 100 {
t.Errorf("expected 100 samples, got %d", len(s.speeds))
}
// First sample should be 5 (0-4 were dropped)
if s.speeds[0] != 5 {
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
}
s.mu.Unlock()
}
func TestSpeedTracker_Concurrent(t *testing.T) {
s := &speedTracker{}
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(v int) {
defer wg.Done()
s.Record(float64(v))
s.Median() // concurrent read
}(i)
}
wg.Wait()
// Should not panic, and should have reasonable state
s.mu.Lock()
if len(s.speeds) == 0 || len(s.speeds) > 100 {
t.Errorf("unexpected speeds length: %d", len(s.speeds))
}
s.mu.Unlock()
}
func TestStreamHasher_Sequential(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data
data := []byte("hello world, this is a test of the stream hasher")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create parts
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(len(data))},
}
sh := newStreamHasher(f, parts, int64(len(data)))
// Mark complete and run
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
if err := sh.Err(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data (3 parts of 10 bytes each)
data := []byte("0123456789ABCDEFGHIJabcdefghij")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create 3 parts
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 10},
{N: 1, Offset: 10, Size: 10},
{N: 2, Offset: 20, Size: 10},
}
sh := newStreamHasher(f, parts, int64(len(data)))
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Mark parts complete out of order: 2, 0, 1
sh.MarkComplete(2)
sh.MarkComplete(0) // This should trigger hashing of part 0
sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
}
func TestStreamHasher_Stop(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
parts := []*blobDownloadPart{
{Offset: 0, Size: 100},
}
sh := newStreamHasher(f, parts, 100)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Stop without completing any parts
sh.Stop()
<-done
// Should exit cleanly without error
if err := sh.Err(); err != nil {
t.Errorf("unexpected error after Stop: %v", err)
}
}
func TestStreamHasher_HashedProgress(t *testing.T) {
// Create temp file with known data
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
data := make([]byte, 1000)
rand.Read(data)
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 500},
{N: 1, Offset: 500, Size: 500},
}
sh := newStreamHasher(f, parts, 1000)
// Initially no progress
if got := sh.Hashed(); got != 0 {
t.Errorf("expected 0 hashed initially, got %d", got)
}
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Complete part 0
sh.MarkComplete(0)
// Give hasher time to process
for i := 0; i < 100; i++ {
if sh.Hashed() >= 500 {
break
}
}
// Complete part 1
sh.MarkComplete(1)
<-done
if got := sh.Hashed(); got != 1000 {
t.Errorf("expected 1000 hashed, got %d", got)
}
}
func BenchmarkSpeedTracker_Record(b *testing.B) {
s := &speedTracker{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Record(float64(i))
}
}
func BenchmarkSpeedTracker_Median(b *testing.B) {
s := &speedTracker{}
// Pre-populate with 100 samples
for i := 0; i < 100; i++ {
s.Record(float64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Median()
}
}
func BenchmarkStreamHasher(b *testing.B) {
// Create temp file with test data
f, err := os.CreateTemp("", "streamhasher_bench")
if err != nil {
b.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
size := 64 * 1024 * 1024 // 64MB
data := make([]byte, size)
rand.Read(data)
if _, err := f.Write(data); err != nil {
b.Fatal(err)
}
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(size)},
}
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sh := newStreamHasher(f, parts, int64(size))
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
}
}
func BenchmarkHashThroughput(b *testing.B) {
// Baseline: raw SHA256 throughput on this machine
size := 256 * 1024 * 1024 // 256MB
data := make([]byte, size)
rand.Read(data)
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
h := sha256.New()
h.Write(data)
h.Sum(nil)
}
}

View File

@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
_, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
return err
}
}
// Note: Digest verification now happens inline during download in blobDownload.run()
// via the streamHasher, so no separate verification pass is needed.
fn(api.ProgressResponse{Status: "writing manifest"})

View File

@@ -10,7 +10,6 @@ import (
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
@@ -327,21 +326,19 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of link names. The sequence is in lexical order.
// Links returns a slice of link names in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using them.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
if err != nil {
yield("", err)
return
}
if !yield(pathToName(path), nil) {
return
}
}
func (c *DiskCache) Links() ([]string, error) {
paths, err := c.links()
if err != nil {
return nil, err
}
names := make([]string, len(paths))
for i, path := range paths {
names[i] = pathToName(path)
}
return names, nil
}
// pathToName converts a path to a name. It is the inverse of nameToPath. The
@@ -372,10 +369,11 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
}
maybe := filepath.Join("manifests", np)
for l, err := range c.links() {
if err != nil {
return "", err
}
paths, err := c.links()
if err != nil {
return "", err
}
for _, l := range paths {
if strings.EqualFold(maybe, l) {
return filepath.Join(c.dir, l), nil
}
@@ -383,22 +381,10 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) links() iter.Seq2[string, error] {
// TODO(bmizerany): reuse empty dirnames if exist
return func(yield func(string, error) bool) {
fsys := os.DirFS(c.dir)
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
if err != nil {
yield("", err)
return
}
for _, manifest := range manifests {
if !yield(manifest, nil) {
return
}
}
}
// links returns a slice of link paths in the cache in lexical order.
func (c *DiskCache) links() ([]string, error) {
fsys := os.DirFS(c.dir)
return fs.Glob(fsys, "manifests/*/*/*/*")
}
type checkWriter struct {

View File

@@ -466,12 +466,9 @@ func testManifestNameReuse(t *testing.T) {
t.Fatalf("g = %v, want %v", g, w)
}
var got []string
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err := c.links()
if err != nil {
t.Fatal(err)
}
want := []string{"manifests/h/n/m/t"}
if !slices.Equal(got, want) {
@@ -487,12 +484,9 @@ func testManifestNameReuse(t *testing.T) {
err = c.Link("h/n/m:T", d1)
check(err)
got = got[:0]
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err = c.links()
if err != nil {
t.Fatal(err)
}
// we should have only one link that is same case as the last link
@@ -554,12 +548,9 @@ func TestNames(t *testing.T) {
check(c.Link("h/n/m:t", mkdigest("1")))
check(c.Link("h/n/m:u", mkdigest("2")))
var got []string
for l, err := range c.Links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err := c.Links()
if err != nil {
t.Fatal(err)
}
want := []string{"h/n/m:t", "h/n/m:u"}
if !slices.Equal(got, want) {

View File

@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
@@ -546,18 +545,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
})
}()
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Note the chunksum stream
// interruption, but do not cancel
// in-flight downloads. We can still
// make progress on them. Once they are
// done, ErrIncomplete will be returned
// below.
update(0, err)
break
}
err = r.chunksums(ctx, name, l, func(cs chunksum) bool {
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
@@ -569,7 +557,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
_, err := c.Get(cacheKeyDigest)
if err == nil {
update(cs.Chunk.Size(), ErrCached)
continue
return true // continue
}
wg.Add(1)
@@ -620,6 +608,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Record the downloading of this chunk.
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
})
return true // continue processing chunks
})
if err != nil {
// Note the chunksum stream interruption, but do not cancel
// in-flight downloads. We can still make progress on them.
// Once they are done, ErrIncomplete will be returned below.
update(0, err)
}
return nil
@@ -674,19 +669,6 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
@@ -811,125 +793,114 @@ type chunksum struct {
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
// chunksums calls fn for each chunksum in the layer. If the layer is under the
// chunking threshold, a single chunksum covering the entire layer is passed to fn.
// If the layer is over the chunking threshold, chunksums are read from the chunksums endpoint.
// Returns an error if the chunksum stream fails, or nil if all chunksums were processed.
// If fn returns false, iteration stops early and chunksums returns nil.
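// A usage sketch, mirroring Pull above (the callback body is the caller's):
//
//    err := r.chunksums(ctx, name, l, func(cs chunksum) bool {
//        // schedule download of cs.Chunk from cs.URL, verify against cs.Digest
//        return true // continue
//    })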
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer, fn func(chunksum) bool) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
fn(cs)
return nil
}
// The response is a sequence of chunksums.
//
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range of the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
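// For example (hypothetical digests), a layer split into two 100MB chunks
// might produce:
//
//    < HTTP/1.1 200 OK
//    < Content-Location: <blobURL>
//    <
//    < sha256:1111... 0-104857599
//    < sha256:2222... 104857600-209715199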
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
return err
}
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != 200 {
return fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
return s.Err()
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, err)
return
return fmt.Errorf("invalid digest: %q", s.Bytes())
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(cs, nil)
return
return err
}
// The response is a sequence of chunksums.
//
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range of the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, err)
return
return fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
if !fn(cs) {
return nil
}
}
}
@@ -1176,8 +1147,8 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest
}
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
// parseChunk parses a byte slice in the form "start-end" and returns the Chunk.
func parseChunk(s []byte) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)

View File

@@ -27,46 +27,20 @@ type Trace struct {
}
func (t *Trace) update(l *Layer, n int64, err error) {
if t.Update != nil {
if t != nil && t.Update != nil {
t.Update(l, n, err)
}
}
type traceKey struct{}
// WithTrace adds a trace to the context for transfer progress reporting.
// WithTrace attaches a Trace to the context for transfer progress reporting.
func WithTrace(ctx context.Context, t *Trace) context.Context {
old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
return context.WithValue(ctx, traceKey{}, t)
}
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
// traceFromContext returns the Trace associated with ctx, or nil if none.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {
return emptyTrace
}
return t
}

View File

@@ -2,44 +2,46 @@ package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
// Retry calls fn repeatedly with quadratic (n^2) backoff until it returns nil,
// a non-retryable error, or the context is canceled. The shouldRetry function
// determines whether an error is retryable.
// Returns the last error encountered, or nil if fn succeeded.
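// A usage sketch (mirroring the pull handler later in this change; canRetry,
// client, and name are the caller's):
//
//    err := Retry(ctx, 3*time.Second, canRetry, func() error {
//        return client.Pull(ctx, name)
//    })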
func Retry(ctx context.Context, maxBackoff time.Duration, shouldRetry func(error) bool, fn func() error) error {
var t *time.Timer
for n := 0; ; n++ {
if err := ctx.Err(); err != nil {
return err
}
if !yield(n, nil) {
return
}
err := fn()
if err == nil {
return nil
}
if !shouldRetry(err) {
return err
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay to between 0.5x and 1.5x the computed value
// in order to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay to between 0.5x and 1.5x the computed value
// in order to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
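// For example, n=5 gives a base delay of 5*5*10ms = 250ms, jittered to roughly
// 125-375ms; note the maxBackoff cap is applied before the jitter.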
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
return ctx.Err()
case <-t.C:
}
}
}

View File

@@ -10,31 +10,70 @@ import (
"time"
)
func TestLoop(t *testing.T) {
func TestRetry(t *testing.T) {
synctest.Run(func() {
last := -1
n := 0
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
err := Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n > 5 {
cancel()
}
return errors.New("keep going")
})
if !errors.Is(err, context.Canceled) {
t.Errorf("err = %v, want context.Canceled", err)
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
if n != 6 {
t.Errorf("n = %d, want 6", n)
}
})
}
func TestRetrySuccess(t *testing.T) {
synctest.Run(func() {
n := 0
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n >= 3 {
return nil // success
}
return errors.New("retry")
})
if err != nil {
t.Errorf("err = %v, want nil", err)
}
if n != 3 {
t.Errorf("n = %d, want 3", n)
}
})
}
func TestRetryNonRetryable(t *testing.T) {
synctest.Run(func() {
permanent := errors.New("permanent error")
n := 0
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool {
return !errors.Is(err, permanent)
}, func() error {
n++
if n >= 2 {
return permanent
}
return errors.New("retry")
})
if !errors.Is(err, permanent) {
t.Errorf("err = %v, want permanent", err)
}
if n != 2 {
t.Errorf("n = %d, want 2", n)
}
})
}

View File

@@ -3,37 +3,46 @@
package backoff
import (
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoopAllocs(t *testing.T) {
var errRetry = errors.New("retry")
func TestRetryAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
tick := 0
Retry(t.Context(), 1, func(err error) bool { return true }, func() error {
tick++
if tick >= i {
break
return nil
}
}
return errRetry
})
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
t.Errorf("[%d ticks]: allocs = %v, want <= %v", i, got, want)
}
}
}
func BenchmarkLoop(b *testing.B) {
func BenchmarkRetry(b *testing.B) {
ctx := b.Context()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
n := 0
Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n == b.N {
break
return nil
}
}
return errRetry
})
})
}

View File

@@ -231,7 +231,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
p, err := decodeParams(r.Body)
if err != nil {
return err
}
@@ -261,7 +261,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
p, err := decodeParams(r.Body)
if err != nil {
return err
}
@@ -293,10 +293,14 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
}
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
// ticker controls periodic progress flushing. It starts paused (very long
// interval) and is activated by start() once all layers are registered,
// so clients see a complete total before progress begins.
ticker := time.NewTicker(1 << 62) // effectively paused until started
defer ticker.Stop()
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
t.Reset(100 * time.Millisecond)
flushProgress()
ticker.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
@@ -320,36 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
}()
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
},
})
done := make(chan error, 1)
go func() (err error) {
defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
if canRetry(err) {
continue
}
return err
}
return nil
go func() {
done <- backoff.Retry(ctx, 3*time.Second, canRetry, func() error {
return s.Client.Pull(ctx, p.model())
})
}()
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
for {
select {
case <-t.C:
case <-ticker.C:
flushProgress()
case err := <-done:
flushProgress()
@@ -374,20 +363,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
}
func decodeUserJSON[T any](r io.Reader) (T, error) {
var v T
err := json.NewDecoder(r).Decode(&v)
func decodeParams(r io.Reader) (*params, error) {
var p params
err := json.NewDecoder(r).Decode(&p)
if err == nil {
return v, nil
return &p, nil
}
var zero T
// Not sure why, but I can't seem to be able to use:
//
// errors.As(err, &json.UnmarshalTypeError{})
//
// This is working fine in stdlib, so I'm not sure what rules changed
// and why this no longer works here. So, we do it the verbose way.
var a *json.UnmarshalTypeError
var b *json.SyntaxError
if errors.As(err, &a) || errors.As(err, &b) {
@@ -396,7 +378,7 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
if errors.Is(err, io.EOF) {
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
}
return zero, err
return nil, err
}
func canRetry(err error) bool {
@@ -408,10 +390,8 @@ func canRetry(err error) bool {
return oe.Temporary()
}
s := err.Error()
return cmp.Or(
errors.Is(err, context.DeadlineExceeded),
strings.Contains(s, "unreachable"),
strings.Contains(s, "no route to host"),
strings.Contains(s, "connection reset by peer"),
)
return errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(s, "unreachable") ||
strings.Contains(s, "no route to host") ||
strings.Contains(s, "connection reset by peer")
}

View File

@@ -1535,9 +1535,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
rs := &registry.Local{

View File

@@ -1,8 +0,0 @@
//go:build !windows
package server
import "os"
func setSparse(*os.File) {
}

View File

@@ -1,17 +0,0 @@
package server
import (
"os"
"golang.org/x/sys/windows"
)
func setSparse(file *os.File) {
// exFat (and other FS types) don't support sparse files, so ignore errors
windows.DeviceIoControl( //nolint:errcheck
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
nil, nil,
)
}