Compare commits

..

2 Commits

Author SHA1 Message Date
ParthSareen
52f757d8a2 cmd: fix gofmt formatting in pi integration
Fixes formatting issues caught by golangci-lint.
2026-02-04 19:27:34 -08:00
ParthSareen
86aa7cd0a6 cmd: add pi integration to ollama launch
Adds support for launching Pi (Pi Coding Agent) via ollama launch.
Pi is a CLI tool distributed via npm that can use Ollama models.

Features:
- Configures ~/.pi/agent/models.json with Ollama provider settings
- Sets defaultProvider and defaultModel in ~/.pi/agent/settings.json
- Uses _launch marker to identify ollama-managed models
- Preserves user-managed models and other config fields
- Backups both files before modification
- Hidden from interactive selector (use ollama launch pi)

Tested with 20 comprehensive tests including safety tests
for user config preservation.

Usage:
  ollama launch pi
  ollama launch pi --model glm-4.7-flash
2026-02-04 18:51:40 -08:00
176 changed files with 8394 additions and 19800 deletions

View File

@@ -1,22 +0,0 @@
name: test-install
on:
pull_request:
paths:
- 'scripts/install.sh'
- '.github/workflows/test-install.yaml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Run install script
run: sh ./scripts/install.sh
env:
OLLAMA_NO_START: 1 # do not start app
- name: Verify ollama is available
run: ollama --version

View File

@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
estimatedInputTokens: estimatedInputTokens,
toolCallsSent: make(map[string]bool),
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
if c.firstWrite {
c.firstWrite = false
// Use actual metrics if available, otherwise use estimate
c.inputTokens = r.Metrics.PromptEvalCount
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
c.inputTokens = c.estimatedInputTokens
}
events = append(events, StreamEvent{
Event: "message_start",
@@ -785,123 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
Model string `json:"model"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
}
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
return estimateTokens(CountTokensRequest{
Model: req.Model,
Messages: req.Messages,
System: req.System,
Tools: req.Tools,
Thinking: req.Thinking,
})
}
// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
InputTokens int `json:"input_tokens"`
}
// estimateTokens returns a rough estimate of tokens (len/4).
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
func estimateTokens(req CountTokensRequest) int {
var totalLen int
// Count system prompt
if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages {
// Count role (always present)
totalLen += len(msg.Role)
// Count content
contentLen := countAnyContent(msg.Content)
totalLen += contentLen
}
for _, tool := range req.Tools {
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
}
// Return len/4 as rough token estimate, minimum 1 if there's any content
tokens := totalLen / 4
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
tokens = 1
}
return tokens
}
func countAnyContent(content any) int {
if content == nil {
return 0
}
switch c := content.(type) {
case string:
return len(c)
case []any:
total := 0
for _, block := range c {
total += countContentBlock(block)
}
return total
default:
if data, err := json.Marshal(content); err == nil {
return len(data)
}
return 0
}
}
func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0
blockType, _ := blockMap["type"].(string)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
}
if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
}
if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

View File

@@ -321,6 +321,8 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -603,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
@@ -676,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -729,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
@@ -776,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
@@ -840,6 +842,10 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
@@ -893,9 +899,11 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -929,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -961,105 +969,3 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
}
})
}
func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"},
},
}
tokens := estimateTokens(req)
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
if tokens < 1 {
t.Errorf("expected at least 1 token, got %d", tokens)
}
// Sanity check: shouldn't be wildly off
if tokens > 10 {
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
}
}
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
tokens := estimateTokens(req)
// System prompt adds to count
if tokens < 5 {
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
}
}
func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"},
},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get the current weather for a location",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
},
},
}
tokens := estimateTokens(req)
// Tools add significant content
if tokens < 10 {
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
}
}
func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this carefully...",
},
map[string]any{
"type": "text",
"text": "Here is my response.",
},
},
},
},
}
tokens := estimateTokens(req)
// Thinking content should be counted
if tokens < 10 {
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
}
}
func TestEstimateTokens_EmptyContent(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{},
}
tokens := estimateTokens(req)
if tokens != 0 {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}

View File

@@ -466,25 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -367,25 +367,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
} else if info.RemoteHost != "" {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
if _, err := client.Whoami(cmd.Context()); err != nil {
return err
}
}
if opts.ShowConnect {
p.StopAndClear()
if isCloud {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
@@ -1763,7 +1752,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return err
return fmt.Errorf("ollama server not responding - %w", err)
}
}
return nil

View File

@@ -3,7 +3,6 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -1660,103 +1659,3 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Think should not be affected by original modification")
}
}
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
remoteHost string
whoamiStatus int
whoamiResp any
expectedError string
}{
{
name: "ollama.com cloud model - user signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "ollama.com cloud model - user not signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
},
expectedError: "unauthorized",
},
{
name: "non-ollama.com remote - no auth check",
remoteHost: "https://other-remote.com",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
whoamiCalled := false
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
case "/api/me":
whoamiCalled = true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.whoamiStatus)
if tt.whoamiResp != nil {
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
default:
http.NotFound(w, r)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
}
if tt.expectedError != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
}
} else {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}
})
}
}

View File

@@ -1,23 +1,18 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner and AliasConfigurer for Claude Code integration
// Claude implements Runner for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
@@ -58,136 +53,10 @@ func (c *Claude) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
cmd.Env = append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
primary := model
fast := model
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
if p := cfg.Aliases["primary"]; p != "" {
primary = p
}
if f := cfg.Aliases["fast"]; f != "" {
fast = f
}
}
return []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := selectPrompt("Select model:", items)
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -5,7 +5,6 @@ import (
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
@@ -104,95 +103,3 @@ func TestClaudeArgs(t *testing.T) {
})
}
}
func TestClaudeModelEnvVars(t *testing.T) {
c := &Claude{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses fast alias for haiku", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("alias primary overrides model param", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
saveIntegration("claude", []string{"saved-model"})
saveAliases("claude", map[string]string{"primary": "saved-model"})
got := envMap(c.modelEnvVars("different-model"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
})
}

View File

@@ -6,14 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Models []string `json:"models"`
}
type config struct {
@@ -53,6 +53,7 @@ func migrateConfig() (bool, error) {
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
@@ -71,6 +72,7 @@ func migrateConfig() (bool, error) {
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}
@@ -131,16 +133,8 @@ func saveIntegration(appName string, models []string) error {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
@@ -160,29 +154,6 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -1,677 +0,0 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}

View File

@@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -9,7 +8,6 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -114,17 +112,9 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}
}
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
@@ -132,7 +122,7 @@ func (d *Droid) Edit(models []string) error {
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: maxOutput,
MaxOutputTokens: 64000,
SupportsImages: false,
ID: modelID,
Index: i,

View File

@@ -1251,55 +1251,6 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
}
}
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
if err := d.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models := settings["customModels"].([]any)
entry := models[0].(map[string]any)
if entry["maxOutputTokens"] != float64(64000) {
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
}
}
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
})
}
}
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()

View File

@@ -39,15 +39,6 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
@@ -57,6 +48,7 @@ var integrations = map[string]Runner{
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
"pi": &Pi{},
}
// recommendedModels are shown when the user has no models or as suggestions.
@@ -72,6 +64,7 @@ var recommendedModels = []selectItem{
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
func selectIntegration() (string, error) {
@@ -138,11 +131,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
} else {
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
@@ -170,123 +159,73 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{
Name: m.Name,
Remote: m.RemoteModel != "",
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) == 0 {
return nil
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string, args []string) error {
@@ -294,33 +233,10 @@ func runIntegration(name, modelName string, args []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
}
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
} else {
delete(aliases, "fast")
}
if err := ac.SetAliases(ctx, aliases); err != nil {
return err
}
return saveAliases(name, aliases)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
@@ -388,87 +304,9 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], passArgs)
}
}
@@ -482,8 +320,6 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
@@ -546,9 +382,8 @@ Examples:
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
Name string
Remote bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
@@ -585,7 +420,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
continue
}
items = append(items, rec)
if strings.HasSuffix(rec.Name, ":cloud") {
if isCloudModel(rec.Name) {
cloudModels[rec.Name] = true
}
}
@@ -645,16 +480,8 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return items, preChecked, existingModels, cloudModels
}
// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
func pullModel(ctx context.Context, client *api.Client, model string) error {

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"fmt"
"slices"
"strings"
@@ -298,15 +297,24 @@ func TestParseArgs(t *testing.T) {
}
func TestIsCloudModel(t *testing.T) {
// isCloudModel now only uses Show API, so nil client always returns false
t.Run("nil client returns false", func(t *testing.T) {
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
for _, model := range models {
if isCloudModel(context.Background(), nil, model) {
t.Errorf("isCloudModel(%q) with nil client should return false", model)
tests := []struct {
name string
want bool
}{
{"glm-4.7:cloud", true},
{"kimi-k2.5:cloud", true},
{"glm-4.7-flash", false},
{"glm-4.7-flash:latest", false},
{"cloud-model", false},
{"model:cloudish", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isCloudModel(tt.name); got != tt.want {
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
}
}
})
})
}
}
func names(items []selectItem) []string {
@@ -501,41 +509,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}

View File

@@ -17,6 +17,8 @@ type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -11,53 +10,12 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -155,8 +113,6 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -166,29 +122,12 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
entry := map[string]any{
models[model] = map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models

View File

@@ -2,7 +2,6 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -496,166 +495,6 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()

196
cmd/config/pi.go Normal file
View File

@@ -0,0 +1,196 @@
package config
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/envconfig"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
for _, model := range models {
if selectedSet[model] {
newModels = append(newModels, map[string]any{
"id": model,
"_launch": true,
})
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0]
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}

609
cmd/config/pi_test.go Normal file
View File

@@ -0,0 +1,609 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestPiIntegration(t *testing.T) {
pi := &Pi{}
t.Run("String", func(t *testing.T) {
if got := pi.String(); got != "Pi" {
t.Errorf("String() = %q, want %q", got, "Pi")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = pi
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = pi
})
}
func TestPiPaths(t *testing.T) {
pi := &Pi{}
t.Run("returns empty when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
paths := pi.Paths()
if len(paths) != 0 {
t.Errorf("Paths() = %v, want empty", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
t.Fatal(err)
}
paths := pi.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}
func TestPiEdit(t *testing.T) {
pi := &Pi{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
configPath := filepath.Join(configDir, "models.json")
cleanup := func() {
os.RemoveAll(configDir)
}
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
return cfg
}
t.Run("returns nil for empty models", func(t *testing.T) {
if err := pi.Edit([]string{}); err != nil {
t.Errorf("Edit([]) error = %v, want nil", err)
}
})
t.Run("creates config with models", func(t *testing.T) {
cleanup()
models := []string{"llama3.2", "qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers, ok := cfg["providers"].(map[string]any)
if !ok {
t.Error("Config missing providers")
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
t.Error("Providers missing ollama")
}
modelsArray, ok := ollama["models"].([]any)
if !ok || len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %v", modelsArray)
}
if ollama["baseUrl"] == nil {
t.Error("Missing baseUrl")
}
if ollama["api"] != "openai-completions" {
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
}
if ollama["apiKey"] != "ollama" {
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
}
})
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://custom:8080/v1",
"api": "custom-api",
"apiKey": "custom-key",
"models": [
{"id": "old-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"new-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
if ollama["baseUrl"] != "http://custom:8080/v1" {
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
}
if ollama["api"] != "custom-api" {
t.Errorf("Custom api not preserved, got %v", ollama["api"])
}
if ollama["apiKey"] != "custom-key" {
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
}
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
} else {
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["id"] != "new-model" {
t.Errorf("Expected new-model, got %v", modelEntry["id"])
}
// Verify _launch marker is present
if modelEntry["_launch"] != true {
t.Errorf("Expected _launch marker to be true")
}
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Old models must have _launch marker to be managed by us
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "old-model-1", "_launch": true},
{"id": "old-model-2", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"new-model-1", "new-model-2"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
t.Errorf("Expected new models, got %v", modelIDs)
}
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
t.Errorf("Old models should have been removed, got %v", modelIDs)
}
})
t.Run("handles partial overlap in model list", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Models must have _launch marker to be managed
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "keep-model", "_launch": true},
{"id": "remove-model", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
newModels := []string{"keep-model", "add-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 2 {
t.Errorf("Expected 2 models, got %d", len(modelsArray))
}
modelIDs := make(map[string]bool)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = true
}
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
}
if modelIDs["remove-model"] {
t.Errorf("remove-model should have been removed")
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
if len(modelsArray) != 1 {
t.Errorf("Expected 1 model, got %d", len(modelsArray))
}
})
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// User has manually configured models in ollama provider (no _launch marker)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "user-model-1"},
{"id": "user-model-2", "customField": "preserved"},
{"id": "ollama-managed", "_launch": true}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
// Add a new ollama-managed model
newModels := []string{"new-ollama-model"}
if err := pi.Edit(newModels); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
// Should have: new-ollama-model (managed) + 2 user models (preserved)
if len(modelsArray) != 3 {
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
}
modelIDs := make(map[string]map[string]any)
for _, m := range modelsArray {
modelObj := m.(map[string]any)
id := modelObj["id"].(string)
modelIDs[id] = modelObj
}
// Verify new model has _launch marker
if m, ok := modelIDs["new-ollama-model"]; !ok {
t.Errorf("new-ollama-model should be present")
} else if m["_launch"] != true {
t.Errorf("new-ollama-model should have _launch marker")
}
// Verify user models are preserved
if _, ok := modelIDs["user-model-1"]; !ok {
t.Errorf("user-model-1 should be preserved")
}
if _, ok := modelIDs["user-model-2"]; !ok {
t.Errorf("user-model-2 should be preserved")
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
t.Errorf("user-model-2 customField should be preserved")
}
// Verify old ollama-managed model is removed (not in new list)
if _, ok := modelIDs["ollama-managed"]; ok {
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
}
})
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create existing settings with other fields
settingsPath := filepath.Join(configDir, "settings.json")
existingSettings := `{
"theme": "dark",
"customSetting": "value",
"defaultProvider": "anthropic",
"defaultModel": "claude-3"
}`
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"llama3.2"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
// Verify defaultProvider is set to ollama
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
// Verify defaultModel is set to first model
if settings["defaultModel"] != "llama3.2" {
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
}
// Verify other fields are preserved
if settings["theme"] != "dark" {
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
}
if settings["customSetting"] != "value" {
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
}
})
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
models := []string{"qwen3:8b"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() error = %v", err)
}
settingsPath := filepath.Join(configDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("settings.json should be created: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("Failed to parse settings: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "qwen3:8b" {
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
}
})
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
// Create corrupt settings
settingsPath := filepath.Join(configDir, "settings.json")
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
t.Fatal(err)
}
models := []string{"test-model"}
if err := pi.Edit(models); err != nil {
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
}
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatalf("Failed to read settings: %v", err)
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
}
if settings["defaultProvider"] != "ollama" {
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
}
if settings["defaultModel"] != "test-model" {
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
}
})
}
func TestPiModels(t *testing.T) {
pi := &Pi{}
t.Run("returns nil when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns models from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "llama3.2"},
{"id": "qwen3:8b"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if len(models) != 2 {
t.Errorf("Models() returned %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
}
})
t.Run("returns sorted models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {
"models": [
{"id": "z-model"},
{"id": "a-model"},
{"id": "m-model"}
]
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
}
})
t.Run("returns nil when models array is missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
config := `{
"providers": {
"ollama": {}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil when models array is missing", models)
}
})
t.Run("handles corrupt config gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".pi", "agent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
t.Fatal(err)
}
models := pi.Models()
if models != nil {
t.Errorf("Models() = %v, want nil for corrupt config", models)
}
})
}

View File

@@ -17,7 +17,6 @@ const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiClearDown = "\033[J"
)

View File

@@ -96,14 +96,6 @@ func TestSelectState(t *testing.T) {
}
})
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState([]selectItem{})
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
@@ -582,19 +574,8 @@ func TestRenderSelect(t *testing.T) {
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "no matches") {
t.Errorf("expected 'no matches' message, got: %s", output)
}
})
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
s := newSelectState([]selectItem{})
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message for empty list with no filter")
t.Error("expected 'no matches' message")
}
})

View File

@@ -10,21 +10,19 @@ import (
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return errNotRunning
return err
}
link, err := os.Readlink(exe)
if err != nil {
return errNotRunning
return err
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errNotRunning
return errors.New("could not find ollama app")
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err

5
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -29,8 +29,6 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -52,7 +50,6 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

31
go.sum
View File

@@ -152,8 +152,6 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -208,39 +206,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@@ -144,47 +144,3 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
t.Fatalf("failed to pull model: %v", err)
}
req := api.GenerateRequest{
Model: "qwen3:0.6b",
Prompt: "Write a long story.",
Stream: &stream,
Logprobs: true,
Options: map[string]any{
"num_predict": 10,
"temperature": 0,
"seed": 123,
},
}
logprobCount := 0
var finalResponse api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
logprobCount += len(resp.Logprobs)
if resp.Done {
finalResponse = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %v", err)
}
if logprobCount != 10 {
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
}
}

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -117,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -143,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var tok tokenizer.Tokenizer
var textProcessor model.TextProcessor
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
tok, err = model.NewTextProcessor(modelPath)
textProcessor, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -156,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if tok == nil {
if textProcessor == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -212,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if tok == nil {
if textProcessor == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -262,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
tok != nil,
textProcessor != nil,
modelPath,
gpuLibs,
status,
@@ -311,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1775,7 +1774,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.tokenizer.Encode(content, false)
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1810,7 +1809,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.tokenizer.Decode(toks)
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}

View File

@@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {

View File

@@ -175,7 +175,6 @@ type Tensor interface {
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
SigmoidOut(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

View File

@@ -1468,13 +1468,6 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:

272
model/bytepairencoding.go Normal file
View File

@@ -0,0 +1,272 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -23,7 +23,6 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -134,7 +133,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -151,7 +150,7 @@ func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
return nil, err
}
tp, ok := m.(tokenizer.Tokenizer)
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -130,7 +129,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &tokenizer.Vocabulary{
vocab := &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -154,17 +153,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, true)
processor = model.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -223,7 +222,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -278,8 +277,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -135,8 +134,8 @@ func init() {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -28,7 +27,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,8 +43,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
*TextModel
poolingType pooling.Type
@@ -32,8 +31,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +54,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -71,19 +70,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
t = tokenizer.NewBytePairEncoding(&vocabulary)
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
t = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
}
m := Model{
Tokenizer: t,
TextProcessor: processor,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,12 +6,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
*TextModel
}
@@ -24,8 +23,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -199,7 +198,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -237,8 +236,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -38,8 +37,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -197,8 +196,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -60,7 +59,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -79,7 +78,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -105,8 +104,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -26,7 +25,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -42,8 +41,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
var processor model.TextProcessor
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -81,16 +80,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v"`
@@ -34,8 +33,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -29,12 +28,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ tokenizer.Tokenizer = (*Model)(nil)
var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*VisionModel `gguf:"v"`
*TextModel
@@ -33,8 +32,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -179,6 +178,29 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -197,29 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
TextProcessor: processor,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -34,7 +33,7 @@ type Options struct {
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -45,24 +44,28 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
Layers: make([]Layer, c.Uint("block_count")),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -93,7 +92,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -140,8 +139,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -28,8 +27,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*Model
poolingType pooling.Type
@@ -35,8 +34,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -160,7 +159,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -219,8 +218,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -136,7 +135,7 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
// Apply shared expert gating
if mlp.SharedGateInp != nil {
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
sharedGateVal = sharedGateVal.Sigmoid(ctx)
// Broadcast gate to match dimensions
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
@@ -208,7 +207,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -354,8 +353,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
*TextModel
*VisionModel `gguf:"v"`
@@ -173,8 +172,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ Tokenizer = (*SentencePiece)(nil)
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizer that use byte tokens like "<0xEA>"
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}

17
model/textprocessor.go Normal file
View File

@@ -0,0 +1,17 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"testing"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements Tokenizer.
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements Tokenizer.
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements Tokenizer.
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements Tokenizer.
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ Tokenizer = (*WordPiece)(nil)
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"slices"

View File

@@ -37,7 +37,6 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -211,9 +210,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := tok.Decode([]int32{int32(tokenID)})
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -243,7 +242,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -515,6 +514,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -703,6 +709,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -738,9 +745,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -765,7 +770,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -774,25 +779,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
seq.pendingResponses = append(seq.pendingResponses, piece)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {
@@ -879,7 +877,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -3,7 +3,6 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -12,15 +11,22 @@ func Execute(args []string) error {
args = args[1:]
}
if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
}
return llamarunner.Execute(args)
}

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) tokenizer.Tokenizer {
func modelHelper(t testing.TB) model.BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) tokenizer.Tokenizer {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
return model.NewBytePairEncoding(
&model.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -1,5 +1,5 @@
#!/bin/sh
# This script installs Ollama on Linux and macOS.
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
@@ -27,7 +27,8 @@ require() {
echo $MISSING
}
OS="$(uname -s)"
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -35,65 +36,6 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
###########################################
# macOS
###########################################
if [ "$OS" = "Darwin" ]; then
NEEDS=$(require curl unzip)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
if [ -n "${OLLAMA_VERSION:-}" ]; then
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
else
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
fi
if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
pkill -x Ollama 2>/dev/null || true
sleep 2
fi
if [ -d "/Applications/Ollama.app" ]; then
status "Removing existing Ollama installation..."
rm -rf "/Applications/Ollama.app"
fi
status "Downloading Ollama for macOS..."
curl --fail --show-error --location --progress-bar \
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
status "Installing Ollama to /Applications..."
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
mv "$TEMP_DIR/Ollama.app" "/Applications/"
status "Adding 'ollama' command to PATH (may require password)..."
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
if [ -z "${OLLAMA_NO_START:-}" ]; then
status "Starting Ollama..."
open -a Ollama --args hidden
fi
status "Install complete. You can now run 'ollama'."
exit 0
fi
###########################################
# Linux
###########################################
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
IS_WSL2=false
KERN=$(uname -r)

View File

@@ -1,422 +0,0 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const (
serverConfigFilename = "server.json"
serverConfigVersion = 1
)
var errAliasCycle = errors.New("alias cycle detected")
type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
type serverConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}
type store struct {
mu sync.RWMutex
path string
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
prefixEntries []aliasEntry // prefix matches, sorted longest-first
}
func createStore(path string) (*store, error) {
store := &store{
path: path,
entries: make(map[string]aliasEntry),
}
if err := store.load(); err != nil {
return nil, err
}
return store, nil
}
func (s *store) load() error {
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var cfg serverConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
return fmt.Errorf("unsupported router config version %d", cfg.Version)
}
for _, entry := range cfg.Aliases {
targetName := model.ParseName(entry.Target)
if !targetName.IsValid() {
slog.Warn("invalid alias target in router config", "target", entry.Target)
continue
}
canonicalTarget := displayAliasName(targetName)
if entry.PrefixMatching {
// Prefix aliases don't need to be valid model names
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
slog.Warn("empty prefix alias in router config")
continue
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: alias,
Target: canonicalTarget,
PrefixMatching: true,
})
} else {
aliasName := model.ParseName(entry.Alias)
if !aliasName.IsValid() {
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
continue
}
canonicalAlias := displayAliasName(aliasName)
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
Alias: canonicalAlias,
Target: canonicalTarget,
}
}
}
// Sort prefix entries by alias length descending (longest prefix wins)
s.sortPrefixEntriesLocked()
return nil
}
func (s *store) saveLocked() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
// Combine exact and prefix entries
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := serverConfig{
Version: serverConfigVersion,
Aliases: entries,
}
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
return err
}
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return err
}
if err := os.Chmod(f.Name(), 0o644); err != nil {
_ = os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), s.path)
}
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
// If a local model exists, do not allow alias shadowing (highest priority).
exists, err := localModelExists(name)
if err != nil {
return name, false, err
}
if exists {
return name, false, nil
}
key := normalizeAliasKey(name)
s.mu.RLock()
entry, exactMatch := s.entries[key]
var prefixMatch *aliasEntry
if !exactMatch {
// Try prefix matching - prefixEntries is sorted longest-first
nameStr := strings.ToLower(displayAliasName(name))
for i := range s.prefixEntries {
prefix := strings.ToLower(s.prefixEntries[i].Alias)
if strings.HasPrefix(nameStr, prefix) {
prefixMatch = &s.prefixEntries[i]
break // First match is longest due to sorting
}
}
}
s.mu.RUnlock()
if !exactMatch && prefixMatch == nil {
return name, false, nil
}
var current string
var visited map[string]struct{}
if exactMatch {
visited = map[string]struct{}{key: {}}
current = entry.Target
} else {
// For prefix match, use the target as-is
visited = map[string]struct{}{}
current = prefixMatch.Target
}
targetKey := normalizeAliasKeyString(current)
for {
targetName := model.ParseName(current)
if !targetName.IsValid() {
return name, false, fmt.Errorf("alias target %q is invalid", current)
}
if _, seen := visited[targetKey]; seen {
return name, false, errAliasCycle
}
visited[targetKey] = struct{}{}
s.mu.RLock()
next, ok := s.entries[targetKey]
s.mu.RUnlock()
if !ok {
return targetName, true, nil
}
current = next.Target
targetKey = normalizeAliasKeyString(current)
}
}
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
targetKey := normalizeAliasKey(target)
s.mu.Lock()
defer s.mu.Unlock()
if prefixMatching {
// For prefix aliases, we skip cycle detection since prefix matching
// works differently and the target is a specific model
aliasStr := displayAliasName(alias)
// Remove any existing prefix entry with the same alias
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
break
}
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: aliasStr,
Target: displayAliasName(target),
PrefixMatching: true,
})
s.sortPrefixEntriesLocked()
return s.saveLocked()
}
aliasKey := normalizeAliasKey(alias)
if aliasKey == targetKey {
return fmt.Errorf("alias cannot point to itself")
}
visited := map[string]struct{}{aliasKey: {}}
currentKey := targetKey
for {
if _, seen := visited[currentKey]; seen {
return errAliasCycle
}
visited[currentKey] = struct{}{}
next, ok := s.entries[currentKey]
if !ok {
break
}
currentKey = normalizeAliasKeyString(next.Target)
}
s.entries[aliasKey] = aliasEntry{
Alias: displayAliasName(alias),
Target: displayAliasName(target),
}
return s.saveLocked()
}
func (s *store) Delete(alias model.Name) (bool, error) {
aliasKey := normalizeAliasKey(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try exact match first
if _, ok := s.entries[aliasKey]; ok {
delete(s.entries, aliasKey)
return true, s.saveLocked()
}
// Try prefix entries
aliasStr := displayAliasName(alias)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
return false, nil
}
// DeleteByString deletes an alias by its raw string value, useful for prefix
// aliases that may not be valid model names.
func (s *store) DeleteByString(alias string) (bool, error) {
alias = strings.TrimSpace(alias)
aliasLower := strings.ToLower(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try prefix entries first (since this is mainly for prefix aliases)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, alias) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
// Also check exact entries by normalized key
if _, ok := s.entries[aliasLower]; ok {
delete(s.entries, aliasLower)
return true, s.saveLocked()
}
return false, nil
}
func (s *store) List() []aliasEntry {
s.mu.RLock()
defer s.mu.RUnlock()
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
return entries
}
func normalizeAliasKey(name model.Name) string {
return strings.ToLower(displayAliasName(name))
}
func (s *store) sortPrefixEntriesLocked() {
sort.Slice(s.prefixEntries, func(i, j int) bool {
// Sort by length descending (longest prefix first)
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
})
}
func normalizeAliasKeyString(value string) string {
n := model.ParseName(value)
if !n.IsValid() {
return strings.ToLower(strings.TrimSpace(value))
}
return normalizeAliasKey(n)
}
func displayAliasName(n model.Name) string {
display := n.DisplayShortest()
if strings.EqualFold(n.Tag, "latest") {
if idx := strings.LastIndex(display, ":"); idx != -1 {
return display[:idx]
}
}
return display
}
func localModelExists(name model.Name) (bool, error) {
manifests, err := manifest.Manifests(true)
if err != nil {
return false, err
}
needle := name.String()
for existing := range manifests {
if strings.EqualFold(existing.String(), needle) {
return true, nil
}
}
return false, nil
}
func serverConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".ollama", serverConfigFilename)
}
return filepath.Join(home, ".ollama", serverConfigFilename)
}
func (s *Server) aliasStore() (*store, error) {
s.aliasesOnce.Do(func() {
s.aliases, s.aliasesErr = createStore(serverConfigPath())
})
return s.aliases, s.aliasesErr
}
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
store, err := s.aliasStore()
if err != nil {
return name, false, err
}
if store == nil {
return name, false, nil
}
return store.ResolveName(name)
}

View File

@@ -22,7 +22,6 @@ import (
"os/signal"
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -52,7 +51,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
xserver "github.com/ollama/ollama/x/server"
)
@@ -82,9 +81,6 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
aliasesOnce sync.Once
aliases *store
aliasesErr error
}
func init() {
@@ -195,16 +191,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name, err = getExistingName(name)
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -1106,7 +1095,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
@@ -1591,9 +1580,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@@ -1964,20 +1950,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
name, err = getExistingName(name)
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(name.String())
m, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):

View File

@@ -1,159 +0,0 @@
package server
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/types/model"
)
type aliasListResponse struct {
Aliases []aliasEntry `json:"aliases"`
}
type aliasDeleteRequest struct {
Alias string `json:"alias"`
}
func (s *Server) ListAliasesHandler(c *gin.Context) {
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var aliases []aliasEntry
if store != nil {
aliases = store.List()
}
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}
func (s *Server) CreateAliasHandler(c *gin.Context) {
var req aliasEntry
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
req.Target = strings.TrimSpace(req.Target)
if req.Alias == "" || req.Target == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
return
}
// Target must always be a valid model name
targetName := model.ParseName(req.Target)
if !targetName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
return
}
var aliasName model.Name
if req.PrefixMatching {
// For prefix aliases, we still parse the alias to normalize it,
// but we allow any non-empty string since prefix patterns may not be valid model names
aliasName = model.ParseName(req.Alias)
// Even if not valid as a model name, we accept it for prefix matching
} else {
aliasName = model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
return
}
exists, err := localModelExists(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if exists {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
return
}
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, errAliasCycle) {
status = http.StatusBadRequest
}
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
return
}
resp := aliasEntry{
Alias: displayAliasName(aliasName),
Target: displayAliasName(targetName),
PrefixMatching: req.PrefixMatching,
}
if req.PrefixMatching && !aliasName.IsValid() {
// For prefix aliases that aren't valid model names, use the raw alias
resp.Alias = req.Alias
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) DeleteAliasHandler(c *gin.Context) {
var req aliasDeleteRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
if req.Alias == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
aliasName := model.ParseName(req.Alias)
var deleted bool
if aliasName.IsValid() {
deleted, err = store.Delete(aliasName)
} else {
// For invalid model names (like prefix aliases), try deleting by raw string
deleted, err = store.DeleteByString(req.Alias)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !deleted {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}

View File

@@ -1,426 +0,0 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "shadowed-model",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req api.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
remoteModel = req.Model
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: req.Model,
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "target-model",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "alias-model",
Messages: []api.Message{{Role: "user", Content: "hi"}},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Model != "alias-model" {
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
}
if remoteModel != "test" {
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
}
}
func TestPrefixAliasBasicMatching(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a prefix alias: "myprefix-" -> "targetmodel"
targetName := model.ParseName("targetmodel")
// Set a prefix alias (using "myprefix-" as the pattern)
store.mu.Lock()
store.prefixEntries = append(store.prefixEntries, aliasEntry{
Alias: "myprefix-",
Target: "targetmodel",
PrefixMatching: true,
})
store.mu.Unlock()
// Test that "myprefix-foo" resolves to "targetmodel"
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != targetName.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
}
// Test that "otherprefix-foo" does not resolve
otherName := model.ParseName("otherprefix-foo")
_, wasResolved, err = store.ResolveName(otherName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatal("expected name not to be resolved")
}
// Test that exact alias takes precedence
exactAlias := model.ParseName("myprefix-exact")
exactTarget := model.ParseName("exacttarget")
if err := store.Set(exactAlias, exactTarget, false); err != nil {
t.Fatal(err)
}
resolved, wasResolved, err = store.ResolveName(exactAlias)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasLongestMatchWins(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add two prefix aliases with overlapping patterns
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
}
store.sortPrefixEntriesLocked()
store.mu.Unlock()
// "abc-def-ghi" should match the longer prefix "abc-def-"
testName := model.ParseName("abc-def-ghi")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedLongTarget := model.ParseName("long-target")
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
}
// "abc-xyz" should match the shorter prefix "abc-"
testName2 := model.ParseName("abc-xyz")
resolved, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedShortTarget := model.ParseName("short-target")
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasChain(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a chain: prefix "test-" -> "intermediate" -> "final"
intermediate := model.ParseName("intermediate")
final := model.ParseName("final")
// Add prefix alias
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
}
store.mu.Unlock()
// Add exact alias for the intermediate step
if err := store.Set(intermediate, final, false); err != nil {
t.Fatal(err)
}
// "test-foo" should resolve through the chain to "final"
testName := model.ParseName("test-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != final.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasCRUD(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a prefix alias via API
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "llama2",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var createResp aliasEntry
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
t.Fatal(err)
}
if !createResp.PrefixMatching {
t.Fatal("expected prefix_matching to be true in response")
}
// List aliases and verify the prefix alias is included
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var listResp aliasListResponse
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
found := false
for _, a := range listResp.Aliases {
if a.PrefixMatching && a.Target == "llama2" {
found = true
break
}
}
if !found {
t.Fatal("expected to find prefix alias in list")
}
// Delete the prefix alias
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify it's deleted
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
for _, a := range listResp.Aliases {
if a.PrefixMatching {
t.Fatal("expected prefix alias to be deleted")
}
}
}
func TestPrefixAliasCaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add a prefix alias with mixed case
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
}
store.mu.Unlock()
// Test that matching is case-insensitive
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (case-insensitive)")
}
expectedTarget := model.ParseName("targetmodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
// Test uppercase request
testName2 := model.ParseName("MYPREFIX-BAR")
_, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (uppercase)")
}
}
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a local model that would match a prefix alias
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "myprefix-localmodel",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Create a prefix alias that would match the local model name
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "someothermodel",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
store, err := s.aliasStore()
if err != nil {
t.Fatal(err)
}
localModelName := model.ParseName("myprefix-localmodel")
resolved, wasResolved, err := store.ResolveName(localModelName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
}
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
}
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
nonLocalName := model.ParseName("myprefix-nonexistent")
resolved, wasResolved, err = store.ResolveName(nonLocalName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected non-local model to resolve via prefix alias")
}
expectedTarget := model.ParseName("someothermodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
}

View File

@@ -5,13 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"os"
"os/exec"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -25,7 +21,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -200,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if s.loadSafetensors(pending) {
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -411,9 +417,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
numParallel = 1
}
// Some architectures are not safe with num_parallel > 1.
// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}
@@ -557,101 +563,20 @@ iGPUScan:
return false
}
func subproc(args, environ []string) (*exec.Cmd, int, error) {
exe, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
}
for range 3 {
// get a random port in the ephemeral range
port := rand.Intn(65535-49152) + 49152
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
cmd.Env = slices.Concat(os.Environ(), environ)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
continue
}
return cmd, port, nil
}
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
if slices.Contains(req.model.Config.Capabilities, "image") {
return s.loadImageGen(req)
}
args := []string{"--mlx-engine", "--model", req.model.ShortName}
environ := []string{}
cmd, port, err := subproc(args, environ)
if err != nil {
req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
llama: &mlxrunner.Client{
Cmd: cmd,
Port: port,
},
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
for range time.Tick(20 * time.Millisecond) {
if err := func() error {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
return runner.llama.Ping(ctx)
}(); err != nil {
continue
}
break
}
return true
}
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
var mode mlxrunner.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = imagegen.ModeImageGen
mode = mlxrunner.ModeImageGen
} else {
mode = imagegen.ModeLLM
mode = mlxrunner.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName, mode)
server, err := mlxrunner.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -1,310 +0,0 @@
package tokenizer
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/ollama/ollama/types/model"
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
func New(root *model.Root) (Tokenizer, error) {
f, err := root.Open("tokenizer.json")
if err != nil {
return nil, err
}
defer f.Close()
var tokenizer struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"`
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
return nil, err
}
special := make(map[int32]struct{})
for _, token := range tokenizer.AddedTokens {
tokenizer.Model.Vocab[token.Content] = token.ID
special[token.ID] = struct{}{}
}
vocab, err := specialTokens(root, tokenizer.Model.Vocab)
if err != nil {
return nil, err
}
vocab.Values = make([]string, len(tokenizer.Model.Vocab))
vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
for content, id := range tokenizer.Model.Vocab {
vocab.Values[id] = content
vocab.Scores[id] = float32(id)
vocab.Types[id] = TOKEN_TYPE_NORMAL
if _, ok := special[id]; ok {
vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
}
}
if tokenizer.Model.Merges != nil {
var pairs [][]string
if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
vocab.Merges = make([]string, len(pairs))
for i, pair := range pairs {
vocab.Merges[i] = pair[0] + " " + pair[1]
}
} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
return nil, err
}
}
vocab.valuesOnce.Do(func() {})
vocab.values = tokenizer.Model.Vocab
if tokenizer.Model.Type == "WordPiece" {
return NewWordPiece(vocab, true), nil
}
if tokenizer.Decoder != nil {
var decoder struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"string"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
return nil, err
}
if decoder.Type == "Sequence" {
for _, d := range decoder.Decoders {
if d.Type == "Replace" && d.Pattern.String == "▁" {
return NewSentencePiece(vocab), nil
}
}
}
}
var pretokenizers []string
if tokenizer.PreTokenizer != nil {
var pretokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string
} `json:"pattern"`
IndividualDigits bool `json:"individual_digits"`
}
}
if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
return nil, err
}
if pretokenizer.Type == "Sequence" {
for _, pretokenizer := range pretokenizer.Pretokenizers {
switch pretokenizer.Type {
case "Digits":
if pretokenizer.IndividualDigits {
pretokenizers = append(pretokenizers, `\d`)
} else {
pretokenizers = append(pretokenizers, `\d+`)
}
case "Punctuation":
pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
case "Split":
pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
case "WhitespaceSplit":
pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
}
}
}
}
return NewBytePairEncoding(vocab, pretokenizers...), nil
}
// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
var s []E
if err := json.Unmarshal(data, &s); err != nil {
var e E
if err := json.Unmarshal(data, &e); err != nil {
return err
}
s = []E{e}
}
*m = valueOrValues[E](s)
return nil
}
type specialTokenIDs struct {
BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}
// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return err
}
if content, ok := m["content"].(string); ok {
s = content
}
}
*t = stringOrContent(s)
return nil
}
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
var vocab Vocabulary
for _, c := range []struct {
name string
fn func(io.Reader) error
}{
{
name: "generation_config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
vocab.BOS = c.BOSTokenID
vocab.EOS = c.EOSTokenID
return nil
},
},
{
name: "config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 {
vocab.BOS = c.BOSTokenID
}
if len(vocab.EOS) == 0 {
vocab.EOS = c.EOSTokenID
}
return nil
},
},
{
name: "tokenizer_config.json",
fn: func(r io.Reader) error {
var c struct {
BOSToken stringOrContent `json:"bos_token"`
EOSToken stringOrContent `json:"eos_token"`
PADToken stringOrContent `json:"pad_token"`
AddBOSToken bool `json:"add_bos_token"`
AddEOSToken bool `json:"add_eos_token"`
}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 && c.BOSToken != "" {
if id, ok := values[string(c.BOSToken)]; ok {
vocab.BOS = []int32{id}
}
}
if len(vocab.EOS) == 0 && c.EOSToken != "" {
if id, ok := values[string(c.EOSToken)]; ok {
vocab.EOS = []int32{id}
}
}
vocab.AddBOS = c.AddBOSToken
vocab.AddEOS = c.AddEOSToken
return nil
},
},
{
name: "special_tokens_map.json",
fn: func(r io.Reader) error {
var c map[string]stringOrContent
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
if id, ok := values[string(bos)]; ok {
vocab.BOS = []int32{id}
}
}
if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
if id, ok := values[string(eos)]; ok {
vocab.EOS = []int32{id}
}
}
return nil
},
},
} {
if err := func() error {
f, err := root.Open(c.name)
if errors.Is(err, os.ErrNotExist) {
return nil
} else if err != nil {
return err
}
defer f.Close()
return c.fn(f)
}(); err != nil {
return nil, err
}
}
return &vocab, nil
}

View File

@@ -1,309 +0,0 @@
package model
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"maps"
"mime"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
func root() (*os.Root, error) {
root, err := os.OpenRoot(envconfig.Models())
if err != nil {
return nil, err
}
for _, sub := range []string{"manifests", "blobs"} {
if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
if err := root.MkdirAll(sub, 0o750); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
}
return root, nil
}
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
f, err := r.Open(filepath.Join("manifests", n.Filepath()))
if err != nil {
return nil, err
}
defer f.Close()
var m manifest
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
blobs := make(map[string]*blob, len(m.Layers)+1)
blobs[NamePrefix] = m.Config
for _, layer := range m.Layers {
if layer.Name == "" && layer.MediaType != "" {
mediatype, _, err := mime.ParseMediaType(layer.MediaType)
if err != nil {
return nil, err
}
if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
layer.Name = NamePrefix + suffix
}
}
blobs[layer.Name] = layer
}
return &Root{
root: r,
name: n,
blobs: blobs,
flags: os.O_RDONLY,
}, nil
}
// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
return &Root{
root: r,
name: n,
blobs: make(map[string]*blob),
flags: os.O_RDWR,
}, nil
}
type blob struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Name string `json:"name,omitempty"`
Size int64 `json:"size"`
// tempfile is the temporary file where the blob data is written.
tempfile *os.File
// hash is the hash.Hash used to compute the blob digest.
hash hash.Hash
}
func (b *blob) Write(p []byte) (int, error) {
return io.MultiWriter(b.tempfile, b.hash).Write(p)
}
func (b *blob) Filepath() string {
return strings.ReplaceAll(b.Digest, ":", "-")
}
type manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *blob `json:"config"`
Layers []*blob `json:"layers"`
}
// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
root *os.Root
name Name
blobs map[string]*blob
flags int
}
const MediaTypePrefix = "application/vnd.ollama"
// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are idenfitied by their media types:
//
// - name: NamePrefix + suffix
// - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
// - name: "./..image.model"
// - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."
// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
if b, ok := r.blobs[name]; ok {
r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
if err != nil {
return nil, err
}
return r, nil
}
return nil, fs.ErrNotExist
}
func (r Root) ReadFile(name string) ([]byte, error) {
f, err := r.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
if r.flags&os.O_RDWR != 0 {
w, err := os.CreateTemp(r.root.Name(), "")
if err != nil {
return nil, err
}
r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
return r.blobs[name], nil
}
return nil, fs.ErrInvalid
}
// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
if r.flags&os.O_RDWR != 0 {
for _, b := range r.blobs {
if b.tempfile != nil {
fi, err := b.tempfile.Stat()
if err != nil {
return err
}
if err := b.tempfile.Close(); err != nil {
return err
}
b.Size = fi.Size()
b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
if b.Name == NamePrefix {
b.MediaType = "application/vnd.docker.container.image.v1+json"
} else {
b.MediaType = MediaTypePrefix + suffix
}
b.Name = ""
}
rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
if err != nil {
return err
}
if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
return err
}
}
}
p := filepath.Join("manifests", r.name.Filepath())
if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
return err
}
} else if err != nil {
return err
}
f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
if err != nil {
return err
}
defer f.Close()
if err := json.NewEncoder(f).Encode(manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: r.blobs[NamePrefix],
Layers: func() []*blob {
blobs := make([]*blob, 0, len(r.blobs))
for name, b := range r.blobs {
if name != NamePrefix {
blobs = append(blobs, b)
}
}
return blobs
}(),
}); err != nil {
return err
}
}
return r.root.Close()
}
// Name returns the name of the file.
func (r Root) Name() Name {
return r.name
}
// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
return maps.Keys(r.blobs)
}
// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
if _, err := filepath.Match(pattern, ""); err != nil {
return nil, err
}
return func(yield func(string) bool) {
for name, blob := range r.blobs {
if matched, _ := filepath.Match(pattern, name); matched {
if !yield(blob.Filepath()) {
return
}
}
}
}, nil
}
func (r Root) JoinPath(parts ...string) string {
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}

View File

@@ -1,90 +0,0 @@
package model
import (
"io"
"strings"
"testing"
)
// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
for m, s := range models {
f, err := Create(m)
if err != nil {
t.Fatal(err)
}
for n, r := range s {
w, err := f.Create(n)
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(w, r); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
}
}
func TestOpen(t *testing.T) {
setup(t, map[Name]map[string]io.Reader{
ParseName("namespace/model"): {
"./.": strings.NewReader(`{"key":"value"}`),
},
ParseName("namespace/model:8b"): {
"./.": strings.NewReader(`{"foo":"bar"}`),
},
ParseName("another/model"): {
"./.": strings.NewReader(`{"another":"config"}`),
},
})
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"./."} {
r, err := f.Open(name)
if err != nil {
t.Fatal(err)
}
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
if err := r.Close(); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
t.Run("does not exist", func(t *testing.T) {
if _, err := Open(ParseName("namespace/unknown")); err == nil {
t.Error("expected error for unknown model")
}
})
t.Run("write", func(t *testing.T) {
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.Create("new-blob"); err == nil {
t.Error("expected error creating blob in read-only mode")
}
})
}

View File

@@ -1,33 +0,0 @@
package model
import (
"io/fs"
"iter"
"path/filepath"
)
func All() (iter.Seq[Name], error) {
r, err := root()
if err != nil {
return nil, err
}
manifests, err := r.OpenRoot("manifests")
if err != nil {
return nil, err
}
matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
if err != nil {
return nil, err
}
return func(yield func(Name) bool) {
for _, match := range matches {
name := ParseNameFromFilepath(filepath.ToSlash(match))
if !yield(name) {
return
}
}
}, nil
}

View File

@@ -227,17 +227,6 @@ func (n Name) String() string {
return b.String()
}
// Set implements [flag.Value]. It parses the provided input as a name string
// and sets the receiver to the parsed value. If the parsed name is not valid,
// ErrUnqualifiedName is returned.
func (n *Name) Set(s string) error {
*n = ParseName(s)
if !n.IsValid() {
return ErrUnqualifiedName
}
return nil
}
// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
var sb strings.Builder

View File

@@ -1,4 +1,4 @@
package manifest
package imagegen
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package manifest
package imagegen
import (
"path/filepath"

View File

@@ -14,8 +14,6 @@ import (
"encoding/json"
"fmt"
"runtime"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// SupportedBackends lists the backends that support image generation.
@@ -43,8 +41,8 @@ func CheckPlatformSupport() error {
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
modelManifest, err := manifest.LoadManifest(modelName)
if err == nil && modelManifest.HasTensorLayers() {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
return modelName
}
return ""
@@ -54,12 +52,12 @@ func ResolveModelName(modelName string) string {
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
modelManifest, err := manifest.LoadManifest(modelName)
manifest, err := LoadManifest(modelName)
if err != nil {
return ""
}
data, err := modelManifest.ReadConfig("model_index.json")
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return ""
}

View File

@@ -12,7 +12,7 @@ import (
"math"
"time"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := manifest.LoadManifest(modelName)
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) er
}
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -9,8 +9,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,11 +38,11 @@ type Config struct {
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
@@ -82,7 +82,7 @@ type MLAAttention struct {
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := modelManifest.ReadConfig("config.json")
configData, err := manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
@@ -634,7 +634,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
@@ -653,7 +653,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := modelManifest.ReadConfig("tokenizer.json")
tokData, err := manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
@@ -664,12 +664,12 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -7,8 +7,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,9 +103,10 @@ type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -131,11 +132,11 @@ type Attention struct {
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
@@ -287,13 +288,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -349,7 +350,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
OutDim int32 // computed from Output
}
// Forward computes final output
@@ -400,12 +401,12 @@ type Transformer struct {
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
@@ -416,7 +417,7 @@ func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Upsample conv
{
prev := x
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -8,8 +8,8 @@ import (
"fmt"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := manifest.LoadManifest(modelName)
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -1,203 +0,0 @@
//go:build mlx
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
// Set up logging with appropriate level from environment
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
// Initialize MLX
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
// Detect model type from capabilities
mode := detectModelMode(*modelName)
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
// Create and start server
server, err := newServer(*modelName, *port, mode)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
// LLM-specific endpoints
if mode == ModeLLM {
mux.HandleFunc("/tokenize", server.tokenizeHandler)
mux.HandleFunc("/embedding", server.embeddingHandler)
}
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("mlx runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
return ModeImageGen
}
}
// Default to LLM mode for safetensors models without known image gen types
return ModeLLM
}
// server holds the model and handles HTTP requests.
type server struct {
mode ModelMode
modelName string
port int
// Image generation model (when mode == ModeImageGen)
imageModel ImageModel
// LLM model (when mode == ModeLLM)
llmModel *llmState
}
// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
s := &server{
mode: mode,
modelName: modelName,
port: port,
}
switch mode {
case ModeImageGen:
if err := s.loadImageModel(); err != nil {
return nil, fmt.Errorf("failed to load image model: %w", err)
}
case ModeLLM:
if err := s.loadLLMModel(); err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}
}
return s, nil
}
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := HealthResponse{Status: "ok"}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch s.mode {
case ModeImageGen:
s.handleImageCompletion(w, r, req)
case ModeLLM:
s.handleLLMCompletion(w, r, req)
}
}
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
var req struct {
Content string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
tok := s.llmModel.model.Tokenizer()
tokens := tok.Encode(req.Content, false)
// Convert int32 to int for JSON response
intTokens := make([]int, len(tokens))
for i, t := range tokens {
intTokens[i] = int(t)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}

View File

@@ -1,471 +0,0 @@
package imagegen
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
mode ModelMode
vramSize uint64
done chan error
client *http.Client
lastErr string // Last stderr line for error reporting
lastErrLock sync.Mutex
}
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the current executable path (we use the same binary with runner subcommand)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
mode: mode,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
s.done <- err
}()
// Wait for subprocess to be ready
if err := s.waitUntilRunning(); err != nil {
s.Close()
return nil, err
}
return s, nil
}
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := s.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-s.done:
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
}
}
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
// Extract raw image bytes from llm.ImageData slice
var images [][]byte
for _, img := range req.Images {
images = append(images, img.Data)
}
// Build request for subprocess
creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: int(req.Steps),
Seed: seed,
Images: images,
}
// Pass LLM options if present
if req.Options != nil {
creq.Options = &RequestOptions{
NumPredict: req.Options.NumPredict,
Temperature: float64(req.Options.Temperature),
TopP: float64(req.Options.TopP),
TopK: req.Options.TopK,
Stop: req.Options.Stop,
}
}
body, err := json.Marshal(creq)
if err != nil {
return err
}
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
// Parse subprocess response
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
// Log stop reason when generation completes
if raw.Done && raw.StopReason != "" {
slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
}
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
if cresp.Done {
return nil
}
}
// Scanner exited without receiving Done - connection was likely closed
scanErr := scanner.Err()
if scanErr != nil {
slog.Error("mlx scanner error", "error", scanErr)
} else {
slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
}
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
}
return scanErr
}
// Close terminates the subprocess.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
select {
case <-s.done:
case <-time.After(5 * time.Second):
s.cmd.Process.Kill()
}
s.cmd = nil
}
return nil
}
// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
}
// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// ContextLength returns the context length (not applicable for image generation).
func (s *Server) ContextLength() int {
return 0
}
// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embeddings not supported for MLX models")
}
// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
body, err := json.Marshal(map[string]string{"content": content})
if err != nil {
return nil, err
}
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
}
var result struct {
Tokens []int `json:"tokens"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
return result.Tokens, nil
}
// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenization not supported for MLX models")
}
// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
return s.cmd.Process.Pid
}
return -1
}
// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)

View File

@@ -1,6 +1,6 @@
//go:build mlx
package manifest
package imagegen
import (
"fmt"
@@ -15,9 +15,9 @@ import (
type ManifestWeights struct {
manifest *ModelManifest
component string
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader from manifest storage.

77
x/kvcache/cache.go Normal file
View File

@@ -0,0 +1,77 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

144
x/kvcache/causal.go Normal file
View File

@@ -0,0 +1,144 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

156
x/kvcache/encoder.go Normal file
View File

@@ -0,0 +1,156 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }

Some files were not shown because too many files have changed in this diff Show More