Compare commits


11 Commits

Author SHA1 Message Date
Michael Yang
ce262df633 fix build duplicate symbols 2026-02-05 17:22:22 -08:00
Michael Yang
e0a4f0aa14 mlxrunner 2026-02-05 17:22:22 -08:00
Michael Yang
9cb6b8ac6d draft: model manifest file interface
this change makes it easier to address blobs by their original names
without rebuilding the filesystem structure
2026-02-05 17:22:22 -08:00
Michael Yang
87e01abd59 resolve circular dependency 2026-02-05 17:21:41 -08:00
Michael Yang
26448f1e7d mv x/mlxrunner x/imagegen 2026-02-05 17:21:41 -08:00
Michael Yang
693101589a simplify runner selection 2026-02-05 17:21:41 -08:00
Michael Yang
b8dff7e342 clean up unused directories 2026-02-05 17:21:41 -08:00
Michael Yang
b350656b23 move tokenizer to separate package 2026-02-05 17:21:37 -08:00
Parth Sareen
8a4b77f9da cmd: set context limits for cloud models in opencode (#14107) 2026-02-05 16:36:46 -08:00
Parth Sareen
5f53fe7884 cmd: ollama launch improvements (#14099) 2026-02-05 15:08:17 -08:00
Bruce MacDonald
7ab4ca0e7f scripts: add macOS support to install.sh (#14060)
Allow installing Ollama on macOS directly from the command line. This is in line with other CLI tools and results in a more streamlined experience when the user is looking to use the CLI specifically.
2026-02-05 14:59:01 -08:00
169 changed files with 18909 additions and 7923 deletions

.github/workflows/test-install.yaml (new file)

@@ -0,0 +1,22 @@
name: test-install
on:
pull_request:
paths:
- 'scripts/install.sh'
- '.github/workflows/test-install.yaml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Run install script
run: sh ./scripts/install.sh
env:
OLLAMA_NO_START: 1 # do not start app
- name: Verify ollama is available
run: ollama --version


@@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
ID: id,
Model: model,
firstWrite: true,
estimatedInputTokens: estimatedInputTokens,
toolCallsSent: make(map[string]bool),
}
}
@@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
if c.firstWrite {
c.firstWrite = false
// Use actual metrics if available, otherwise use estimate
c.inputTokens = r.Metrics.PromptEvalCount
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
c.inputTokens = c.estimatedInputTokens
}
events = append(events, StreamEvent{
Event: "message_start",
@@ -779,3 +785,123 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
Model string `json:"model"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
}
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
return estimateTokens(CountTokensRequest{
Model: req.Model,
Messages: req.Messages,
System: req.System,
Tools: req.Tools,
Thinking: req.Thinking,
})
}
// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
InputTokens int `json:"input_tokens"`
}
// estimateTokens returns a rough estimate of tokens (len/4).
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
func estimateTokens(req CountTokensRequest) int {
var totalLen int
// Count system prompt
if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages {
// Count role (always present)
totalLen += len(msg.Role)
// Count content
contentLen := countAnyContent(msg.Content)
totalLen += contentLen
}
for _, tool := range req.Tools {
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
}
// Return len/4 as rough token estimate, minimum 1 if there's any content
tokens := totalLen / 4
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
tokens = 1
}
return tokens
}
func countAnyContent(content any) int {
if content == nil {
return 0
}
switch c := content.(type) {
case string:
return len(c)
case []any:
total := 0
for _, block := range c {
total += countContentBlock(block)
}
return total
default:
if data, err := json.Marshal(content); err == nil {
return len(data)
}
return 0
}
}
func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0
blockType, _ := blockMap["type"].(string)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
}
if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
}
if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}
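Illustrative usage (not part of this change): a minimal sketch of how the estimate flows through the converter. EstimateInputTokens is computed once from the incoming request and handed to NewStreamConverter; Process only falls back to it when the model reports a PromptEvalCount of 0. The function name and model string below are hypothetical.
// Usage sketch (illustrative only): estimate input tokens up front and let the
// converter fall back to them when the chunk carries no metrics.
func exampleEstimatedTokens() []StreamEvent {
	req := MessagesRequest{
		Model:    "example-model", // hypothetical model name
		Messages: []MessageParam{{Role: "user", Content: "Hello, world!"}},
	}
	estimate := EstimateInputTokens(req) // len/4 heuristic: "user" (4) + "Hello, world!" (13) = 17 chars, ~4 tokens
	conv := NewStreamConverter("msg_example", req.Model, estimate)
	// A chunk without metrics: message_start reports the estimated input tokens.
	chunk := api.ChatResponse{
		Model:   req.Model,
		Message: api.Message{Role: "assistant", Content: "Hi"},
	}
	return conv.Process(chunk)
}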


@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk
resp1 := api.ChatResponse{
@@ -678,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -731,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
@@ -778,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
@@ -842,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
@@ -899,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -937,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -969,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
}
})
}
func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"},
},
}
tokens := estimateTokens(req)
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
if tokens < 1 {
t.Errorf("expected at least 1 token, got %d", tokens)
}
// Sanity check: shouldn't be wildly off
if tokens > 10 {
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
}
}
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
tokens := estimateTokens(req)
// System prompt adds to count
if tokens < 5 {
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
}
}
func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"},
},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get the current weather for a location",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
},
},
}
tokens := estimateTokens(req)
// Tools add significant content
if tokens < 10 {
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
}
}
func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this carefully...",
},
map[string]any{
"type": "text",
"text": "Here is my response.",
},
},
},
},
}
tokens := estimateTokens(req)
// Thinking content should be counted
if tokens < 10 {
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
}
}
func TestEstimateTokens_EmptyContent(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{},
}
tokens := estimateTokens(req)
if tokens != 0 {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}


@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}
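Illustrative usage (not part of this change): a minimal in-package sketch of the new alias endpoints. The prefix and target model below are hypothetical; PrefixMatching mirrors how the Claude integration later registers prefixes such as claude-sonnet-.
// Usage sketch (illustrative only): create a prefix alias, then remove it.
func exampleAliases(ctx context.Context, c *Client) error {
	if err := c.SetAliasExperimental(ctx, &AliasRequest{
		Alias:          "claude-sonnet-", // hypothetical prefix
		Target:         "example-model",  // hypothetical target model
		PrefixMatching: true,
	}); err != nil {
		return err
	}
	// Delete the alias when it is no longer wanted.
	return c.DeleteAliasExperimental(ctx, &AliasDeleteRequest{Alias: "claude-sonnet-"})
}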


@@ -1763,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err)
return err
}
}
return nil


@@ -1,18 +1,23 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
@@ -60,3 +65,104 @@ func (c *Claude) Run(model string, args []string) error {
)
return cmd.Run()
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := selectPrompt("Select model:", items)
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}
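Illustrative usage (not part of this change): the two paths through SetAliases, with hypothetical model names. A cloud primary carries a fast alias and registers both claude prefixes; a local primary omits fast, so the prefix aliases are deleted instead.
// Usage sketch (illustrative only) of the alias maps produced by ConfigureAliases.
func exampleClaudeAliases(ctx context.Context) error {
	c := &Claude{}
	// Cloud primary: fast is set, so claude-sonnet- and claude-haiku- are created.
	if err := c.SetAliases(ctx, map[string]string{
		"primary": "example-cloud-model",
		"fast":    "example-cloud-model",
	}); err != nil {
		return err
	}
	// Local primary: fast is absent, so both prefix aliases are deleted.
	return c.SetAliases(ctx, map[string]string{"primary": "example-local-model"})
}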


@@ -13,7 +13,8 @@ import (
)
type integration struct {
Models []string `json:"models"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
}
type config struct {
@@ -133,8 +134,16 @@ func saveIntegration(appName string, models []string) error {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
}
return save(cfg)
@@ -154,6 +163,29 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {


@@ -0,0 +1,677 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}


@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})


@@ -39,6 +39,15 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
@@ -129,7 +138,11 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
if err != nil {
return nil, err
}
@@ -157,73 +170,123 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{
Name: m.Name,
Remote: m.RemoteModel != "",
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
if len(selectedCloudModels) == 0 {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return fmt.Errorf("%s requires sign in", modelList)
}
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string, args []string) error {
@@ -231,10 +294,33 @@ func runIntegration(name, modelName string, args []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
}
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
} else {
delete(aliases, "fast")
}
if err := ac.SetAliases(ctx, aliases); err != nil {
return err
}
return saveAliases(name, aliases)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
@@ -302,9 +388,87 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], passArgs)
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
@@ -318,6 +482,8 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
@@ -380,8 +546,9 @@ Examples:
}
type modelInfo struct {
Name string
Remote bool
Name string
Remote bool
ToolCapable bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
@@ -418,7 +585,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
continue
}
items = append(items, rec)
if isCloudModel(rec.Name) {
if strings.HasSuffix(rec.Name, ":cloud") {
cloudModels[rec.Name] = true
}
}
@@ -478,8 +645,16 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return items, preChecked, existingModels, cloudModels
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
func pullModel(ctx context.Context, client *api.Client, model string) error {


@@ -1,6 +1,7 @@
package config
import (
"context"
"fmt"
"slices"
"strings"
@@ -297,24 +298,15 @@ func TestParseArgs(t *testing.T) {
}
func TestIsCloudModel(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"glm-4.7:cloud", true},
{"kimi-k2.5:cloud", true},
{"glm-4.7-flash", false},
{"glm-4.7-flash:latest", false},
{"cloud-model", false},
{"model:cloudish", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isCloudModel(tt.name); got != tt.want {
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
// isCloudModel now only uses Show API, so nil client always returns false
t.Run("nil client returns false", func(t *testing.T) {
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
for _, model := range models {
if isCloudModel(context.Background(), nil, model) {
t.Errorf("isCloudModel(%q) with nil client should return false", model)
}
})
}
}
})
}
func names(items []selectItem) []string {
@@ -509,3 +501,41 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}


@@ -17,8 +17,6 @@ type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {


@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -10,12 +11,52 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
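Illustrative usage (not part of this change): the lookup accepts either spelling of a cloud model name, so callers do not need to strip the :cloud suffix themselves.
// Usage sketch (illustrative only) of the suffix-tolerant lookup above.
func exampleLookupLimit() {
	if l, ok := lookupCloudModelLimit("glm-4.7:cloud"); ok {
		// Same result as lookupCloudModelLimit("glm-4.7").
		fmt.Printf("context=%d output=%d\n", l.Context, l.Output) // 202752 131072
	}
}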
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -113,6 +154,8 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -122,12 +165,29 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
models[model] = map[string]any{
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models


@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -495,6 +496,165 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()


@@ -17,6 +17,7 @@ const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiClearDown = "\033[J"
)


@@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) {
}
})
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState([]selectItem{})
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
@@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) {
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "no matches") {
t.Errorf("expected 'no matches' message, got: %s", output)
}
})
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
s := newSelectState([]selectItem{})
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
t.Error("expected 'no matches' message for empty list with no filter")
}
})


@@ -10,19 +10,21 @@ import (
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
return errNotRunning
}
link, err := os.Readlink(exe)
if err != nil {
return err
return errNotRunning
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app")
return errNotRunning
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err


@@ -188,8 +188,6 @@ func LogLevel() slog.Level {
var (
// FlashAttention enables the experimental flash attention feature.
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
// DebugLogRequests logs inference requests to disk for replay/debugging.
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
// KvCacheType is the quantization type for the K/V cache.
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
// NoHistory disables readline history.
@@ -275,27 +273,26 @@ type EnvVar struct {
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

5 go.mod

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -29,6 +29,8 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -50,6 +52,7 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

31 go.sum

@@ -152,6 +152,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -206,12 +208,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=


@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -116,7 +117,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tok != nil,
modelPath,
gpuLibs,
status,
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}


@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
}
if req.Stream {


@@ -1,272 +0,0 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and its corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
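// The byte-to-rune remapping above is the byte-level BPE convention: control,
// whitespace, and other non-printable bytes are shifted into printable rune
// ranges so every byte has a visible vocabulary form; Decode reverses it.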
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}


@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}


@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),


@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),


@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),


@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),


@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),


@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),


@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),


@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,


@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),


@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),


@@ -1,17 +0,0 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}


@@ -1,53 +0,0 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}


@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return


@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -11,22 +12,15 @@ func Execute(args []string) error {
args = args[1:]
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)
}


@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}


@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
&model.Vocabulary{
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,


@@ -1,5 +1,5 @@
#!/bin/sh
# This script installs Ollama on Linux.
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
@@ -27,8 +27,7 @@ require() {
echo $MISSING
}
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
OS="$(uname -s)"
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -36,6 +35,65 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
###########################################
# macOS
###########################################
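# This block downloads the macOS release zip, installs Ollama.app into
# /Applications, and symlinks the bundled CLI into /usr/local/bin.
# OLLAMA_VERSION pins a specific release and OLLAMA_NO_START=1 skips
# launching the app afterwards (illustrative: OLLAMA_NO_START=1 sh ./install.sh).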
if [ "$OS" = "Darwin" ]; then
NEEDS=$(require curl unzip)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
if [ -n "${OLLAMA_VERSION:-}" ]; then
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
else
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
fi
if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
pkill -x Ollama 2>/dev/null || true
sleep 2
fi
if [ -d "/Applications/Ollama.app" ]; then
status "Removing existing Ollama installation..."
rm -rf "/Applications/Ollama.app"
fi
status "Downloading Ollama for macOS..."
curl --fail --show-error --location --progress-bar \
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
status "Installing Ollama to /Applications..."
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
mv "$TEMP_DIR/Ollama.app" "/Applications/"
status "Adding 'ollama' command to PATH (may require password)..."
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
if [ -z "${OLLAMA_NO_START:-}" ]; then
status "Starting Ollama..."
open -a Ollama --args hidden
fi
status "Install complete. You can now run 'ollama'."
exit 0
fi
###########################################
# Linux
###########################################
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
IS_WSL2=false
KERN=$(uname -r)

422 server/aliases.go Normal file

@@ -0,0 +1,422 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const (
serverConfigFilename = "server.json"
serverConfigVersion = 1
)
var errAliasCycle = errors.New("alias cycle detected")
type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
type serverConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}
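// Illustrative server.json contents (alias and target values are made up;
// field names follow the struct tags above):
//
//	{
//	  "version": 1,
//	  "aliases": [
//	    {"alias": "my-model", "target": "llama3.2:latest"},
//	    {"alias": "team/", "target": "llama3.2:latest", "prefix_matching": true}
//	  ]
//	}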
type store struct {
mu sync.RWMutex
path string
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
prefixEntries []aliasEntry // prefix matches, sorted longest-first
}
func createStore(path string) (*store, error) {
store := &store{
path: path,
entries: make(map[string]aliasEntry),
}
if err := store.load(); err != nil {
return nil, err
}
return store, nil
}
func (s *store) load() error {
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var cfg serverConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
return fmt.Errorf("unsupported router config version %d", cfg.Version)
}
for _, entry := range cfg.Aliases {
targetName := model.ParseName(entry.Target)
if !targetName.IsValid() {
slog.Warn("invalid alias target in router config", "target", entry.Target)
continue
}
canonicalTarget := displayAliasName(targetName)
if entry.PrefixMatching {
// Prefix aliases don't need to be valid model names
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
slog.Warn("empty prefix alias in router config")
continue
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: alias,
Target: canonicalTarget,
PrefixMatching: true,
})
} else {
aliasName := model.ParseName(entry.Alias)
if !aliasName.IsValid() {
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
continue
}
canonicalAlias := displayAliasName(aliasName)
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
Alias: canonicalAlias,
Target: canonicalTarget,
}
}
}
// Sort prefix entries by alias length descending (longest prefix wins)
s.sortPrefixEntriesLocked()
return nil
}
func (s *store) saveLocked() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
// Combine exact and prefix entries
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := serverConfig{
Version: serverConfigVersion,
Aliases: entries,
}
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
return err
}
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return err
}
if err := os.Chmod(f.Name(), 0o644); err != nil {
_ = os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), s.path)
}
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
// If a local model exists, do not allow alias shadowing (highest priority).
exists, err := localModelExists(name)
if err != nil {
return name, false, err
}
if exists {
return name, false, nil
}
key := normalizeAliasKey(name)
s.mu.RLock()
entry, exactMatch := s.entries[key]
var prefixMatch *aliasEntry
if !exactMatch {
// Try prefix matching - prefixEntries is sorted longest-first
nameStr := strings.ToLower(displayAliasName(name))
for i := range s.prefixEntries {
prefix := strings.ToLower(s.prefixEntries[i].Alias)
if strings.HasPrefix(nameStr, prefix) {
prefixMatch = &s.prefixEntries[i]
break // First match is longest due to sorting
}
}
}
s.mu.RUnlock()
if !exactMatch && prefixMatch == nil {
return name, false, nil
}
var current string
var visited map[string]struct{}
if exactMatch {
visited = map[string]struct{}{key: {}}
current = entry.Target
} else {
// For prefix match, use the target as-is
visited = map[string]struct{}{}
current = prefixMatch.Target
}
targetKey := normalizeAliasKeyString(current)
for {
targetName := model.ParseName(current)
if !targetName.IsValid() {
return name, false, fmt.Errorf("alias target %q is invalid", current)
}
if _, seen := visited[targetKey]; seen {
return name, false, errAliasCycle
}
visited[targetKey] = struct{}{}
s.mu.RLock()
next, ok := s.entries[targetKey]
s.mu.RUnlock()
if !ok {
return targetName, true, nil
}
current = next.Target
targetKey = normalizeAliasKeyString(current)
}
}
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
targetKey := normalizeAliasKey(target)
s.mu.Lock()
defer s.mu.Unlock()
if prefixMatching {
// For prefix aliases, we skip cycle detection since prefix matching
// works differently and the target is a specific model
aliasStr := displayAliasName(alias)
// Remove any existing prefix entry with the same alias
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
break
}
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: aliasStr,
Target: displayAliasName(target),
PrefixMatching: true,
})
s.sortPrefixEntriesLocked()
return s.saveLocked()
}
aliasKey := normalizeAliasKey(alias)
if aliasKey == targetKey {
return fmt.Errorf("alias cannot point to itself")
}
visited := map[string]struct{}{aliasKey: {}}
currentKey := targetKey
for {
if _, seen := visited[currentKey]; seen {
return errAliasCycle
}
visited[currentKey] = struct{}{}
next, ok := s.entries[currentKey]
if !ok {
break
}
currentKey = normalizeAliasKeyString(next.Target)
}
s.entries[aliasKey] = aliasEntry{
Alias: displayAliasName(alias),
Target: displayAliasName(target),
}
return s.saveLocked()
}
func (s *store) Delete(alias model.Name) (bool, error) {
aliasKey := normalizeAliasKey(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try exact match first
if _, ok := s.entries[aliasKey]; ok {
delete(s.entries, aliasKey)
return true, s.saveLocked()
}
// Try prefix entries
aliasStr := displayAliasName(alias)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
return false, nil
}
// DeleteByString deletes an alias by its raw string value, useful for prefix
// aliases that may not be valid model names.
func (s *store) DeleteByString(alias string) (bool, error) {
alias = strings.TrimSpace(alias)
aliasLower := strings.ToLower(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try prefix entries first (since this is mainly for prefix aliases)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, alias) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
// Also check exact entries by normalized key
if _, ok := s.entries[aliasLower]; ok {
delete(s.entries, aliasLower)
return true, s.saveLocked()
}
return false, nil
}
func (s *store) List() []aliasEntry {
s.mu.RLock()
defer s.mu.RUnlock()
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
return entries
}
func normalizeAliasKey(name model.Name) string {
return strings.ToLower(displayAliasName(name))
}
func (s *store) sortPrefixEntriesLocked() {
sort.Slice(s.prefixEntries, func(i, j int) bool {
// Sort by length descending (longest prefix first)
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
})
}
func normalizeAliasKeyString(value string) string {
n := model.ParseName(value)
if !n.IsValid() {
return strings.ToLower(strings.TrimSpace(value))
}
return normalizeAliasKey(n)
}
func displayAliasName(n model.Name) string {
display := n.DisplayShortest()
if strings.EqualFold(n.Tag, "latest") {
if idx := strings.LastIndex(display, ":"); idx != -1 {
return display[:idx]
}
}
return display
}
func localModelExists(name model.Name) (bool, error) {
manifests, err := manifest.Manifests(true)
if err != nil {
return false, err
}
needle := name.String()
for existing := range manifests {
if strings.EqualFold(existing.String(), needle) {
return true, nil
}
}
return false, nil
}
func serverConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".ollama", serverConfigFilename)
}
return filepath.Join(home, ".ollama", serverConfigFilename)
}
func (s *Server) aliasStore() (*store, error) {
s.aliasesOnce.Do(func() {
s.aliases, s.aliasesErr = createStore(serverConfigPath())
})
return s.aliases, s.aliasesErr
}
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
store, err := s.aliasStore()
if err != nil {
return name, false, err
}
if store == nil {
return name, false, nil
}
return store.ResolveName(name)
}
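For reference, a minimal sketch of what ~/.ollama/server.json could contain, matching the serverConfig and aliasEntry JSON tags defined at the top of this file (the model names are placeholders, not part of this change):

  {
    "version": 1,
    "aliases": [
      { "alias": "my-alias", "target": "llama3.2" },
      { "alias": "team-", "target": "qwen3", "prefix_matching": true }
    ]
  }

Exact aliases must parse as valid model names; prefix entries may be arbitrary non-empty strings and are matched longest-first.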

View File

@@ -1,144 +0,0 @@
package server
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/envconfig"
)
type inferenceRequestLogger struct {
dir string
counter uint64
}
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
if err != nil {
return nil, err
}
return &inferenceRequestLogger{dir: dir}, nil
}
func (s *Server) initRequestLogging() error {
if !envconfig.DebugLogRequests() {
return nil
}
requestLogger, err := newInferenceRequestLogger()
if err != nil {
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
}
s.requestLogger = requestLogger
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
return nil
}
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
if s.requestLogger == nil {
return handlers
}
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
}
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
method := c.Request.Method
host := c.Request.Host
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
contentType := c.GetHeader("Content-Type")
var body []byte
if c.Request.Body != nil {
var err error
body, err = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err != nil {
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
}
}
c.Next()
l.log(route, method, scheme, host, contentType, body)
}
}
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
if l == nil || l.dir == "" {
return
}
if contentType == "" {
contentType = "application/json"
}
if host == "" || scheme == "" {
base := envconfig.Host()
if host == "" {
host = base.Host
}
if scheme == "" {
scheme = base.Scheme
}
}
routeForFilename := sanitizeRouteForFilename(route)
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
bodyPath := filepath.Join(l.dir, bodyFilename)
curlPath := filepath.Join(l.dir, curlFilename)
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
slog.Warn("failed to write debug request body", "route", route, "error", err)
return
}
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
return
}
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
}
func sanitizeRouteForFilename(route string) string {
route = strings.TrimPrefix(route, "/")
if route == "" {
return "root"
}
var b strings.Builder
b.Grow(len(route))
for _, r := range route {
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
b.WriteRune(r)
} else {
b.WriteByte('_')
}
}
return b.String()
}

View File

@@ -22,6 +22,7 @@ import (
"os/signal"
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -51,7 +52,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
@@ -81,7 +82,9 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
aliasesOnce sync.Once
aliases *store
aliasesErr error
}
func init() {
@@ -192,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
// We cannot currently consolidate this into GetModel because doing so would
// induce infinite recursion given the current code structure.
name, err := getExistingName(name)
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -1096,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
@@ -1581,27 +1591,30 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)...)
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)...)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)...)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -1651,9 +1664,6 @@ func Serve(ln net.Listener) error {
}
s := &Server{addr: ln.Addr()}
if err := s.initRequestLogging(); err != nil {
return err
}
var rc *ollama.Registry
if useClient2 {
@@ -1954,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
name, err := getExistingName(name)
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(req.Model)
m, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):

server/routes_aliases.go Normal file

@@ -0,0 +1,159 @@
package server
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/types/model"
)
type aliasListResponse struct {
Aliases []aliasEntry `json:"aliases"`
}
type aliasDeleteRequest struct {
Alias string `json:"alias"`
}
func (s *Server) ListAliasesHandler(c *gin.Context) {
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var aliases []aliasEntry
if store != nil {
aliases = store.List()
}
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}
func (s *Server) CreateAliasHandler(c *gin.Context) {
var req aliasEntry
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
req.Target = strings.TrimSpace(req.Target)
if req.Alias == "" || req.Target == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
return
}
// Target must always be a valid model name
targetName := model.ParseName(req.Target)
if !targetName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
return
}
var aliasName model.Name
if req.PrefixMatching {
// For prefix aliases, we still parse the alias to normalize it,
// but we allow any non-empty string since prefix patterns may not be valid model names
aliasName = model.ParseName(req.Alias)
// Even if not valid as a model name, we accept it for prefix matching
} else {
aliasName = model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
return
}
exists, err := localModelExists(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if exists {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
return
}
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, errAliasCycle) {
status = http.StatusBadRequest
}
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
return
}
resp := aliasEntry{
Alias: displayAliasName(aliasName),
Target: displayAliasName(targetName),
PrefixMatching: req.PrefixMatching,
}
if req.PrefixMatching && !aliasName.IsValid() {
// For prefix aliases that aren't valid model names, use the raw alias
resp.Alias = req.Alias
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) DeleteAliasHandler(c *gin.Context) {
var req aliasDeleteRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
if req.Alias == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
aliasName := model.ParseName(req.Alias)
var deleted bool
if aliasName.IsValid() {
deleted, err = store.Delete(aliasName)
} else {
// For invalid model names (like prefix aliases), try deleting by raw string
deleted, err = store.DeleteByString(req.Alias)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !deleted {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
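A minimal sketch of driving these handlers over HTTP from Go (assumes a server listening on the default 127.0.0.1:11434; the alias and target names are placeholders):

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://127.0.0.1:11434/api/experimental/aliases"

	// Create an exact alias pointing at an installed model.
	body := []byte(`{"alias":"my-alias","target":"llama3.2"}`)
	resp, err := http.Post(base, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	resp.Body.Close()

	// List all configured aliases (exact and prefix).
	resp, err = http.Get(base)
	if err != nil {
		panic(err)
	}
	out, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println(string(out))

	// Delete the alias again; the handler reads {"alias": ...} from the body.
	req, err := http.NewRequest(http.MethodDelete, base, bytes.NewReader([]byte(`{"alias":"my-alias"}`)))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err = http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
}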

View File

@@ -0,0 +1,426 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "shadowed-model",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req api.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
remoteModel = req.Model
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: req.Model,
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "target-model",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "alias-model",
Messages: []api.Message{{Role: "user", Content: "hi"}},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Model != "alias-model" {
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
}
if remoteModel != "test" {
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
}
}
func TestPrefixAliasBasicMatching(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a prefix alias: "myprefix-" -> "targetmodel"
targetName := model.ParseName("targetmodel")
// Set a prefix alias (using "myprefix-" as the pattern)
store.mu.Lock()
store.prefixEntries = append(store.prefixEntries, aliasEntry{
Alias: "myprefix-",
Target: "targetmodel",
PrefixMatching: true,
})
store.mu.Unlock()
// Test that "myprefix-foo" resolves to "targetmodel"
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != targetName.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
}
// Test that "otherprefix-foo" does not resolve
otherName := model.ParseName("otherprefix-foo")
_, wasResolved, err = store.ResolveName(otherName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatal("expected name not to be resolved")
}
// Test that exact alias takes precedence
exactAlias := model.ParseName("myprefix-exact")
exactTarget := model.ParseName("exacttarget")
if err := store.Set(exactAlias, exactTarget, false); err != nil {
t.Fatal(err)
}
resolved, wasResolved, err = store.ResolveName(exactAlias)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasLongestMatchWins(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add two prefix aliases with overlapping patterns
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
}
store.sortPrefixEntriesLocked()
store.mu.Unlock()
// "abc-def-ghi" should match the longer prefix "abc-def-"
testName := model.ParseName("abc-def-ghi")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedLongTarget := model.ParseName("long-target")
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
}
// "abc-xyz" should match the shorter prefix "abc-"
testName2 := model.ParseName("abc-xyz")
resolved, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedShortTarget := model.ParseName("short-target")
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasChain(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a chain: prefix "test-" -> "intermediate" -> "final"
intermediate := model.ParseName("intermediate")
final := model.ParseName("final")
// Add prefix alias
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
}
store.mu.Unlock()
// Add exact alias for the intermediate step
if err := store.Set(intermediate, final, false); err != nil {
t.Fatal(err)
}
// "test-foo" should resolve through the chain to "final"
testName := model.ParseName("test-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != final.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasCRUD(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a prefix alias via API
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "llama2",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var createResp aliasEntry
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
t.Fatal(err)
}
if !createResp.PrefixMatching {
t.Fatal("expected prefix_matching to be true in response")
}
// List aliases and verify the prefix alias is included
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var listResp aliasListResponse
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
found := false
for _, a := range listResp.Aliases {
if a.PrefixMatching && a.Target == "llama2" {
found = true
break
}
}
if !found {
t.Fatal("expected to find prefix alias in list")
}
// Delete the prefix alias
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify it's deleted
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
for _, a := range listResp.Aliases {
if a.PrefixMatching {
t.Fatal("expected prefix alias to be deleted")
}
}
}
func TestPrefixAliasCaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add a prefix alias with mixed case
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
}
store.mu.Unlock()
// Test that matching is case-insensitive
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (case-insensitive)")
}
expectedTarget := model.ParseName("targetmodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
// Test uppercase request
testName2 := model.ParseName("MYPREFIX-BAR")
_, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (uppercase)")
}
}
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a local model that would match a prefix alias
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "myprefix-localmodel",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Create a prefix alias that would match the local model name
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "someothermodel",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
store, err := s.aliasStore()
if err != nil {
t.Fatal(err)
}
localModelName := model.ParseName("myprefix-localmodel")
resolved, wasResolved, err := store.ResolveName(localModelName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
}
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
}
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
nonLocalName := model.ParseName("myprefix-nonexistent")
resolved, wasResolved, err = store.ResolveName(nonLocalName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected non-local model to resolve via prefix alias")
}
expectedTarget := model.ParseName("someothermodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
}

View File

@@ -1,128 +0,0 @@
package server
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
gin.SetMode(gin.TestMode)
logDir := t.TempDir()
requestLogger := &inferenceRequestLogger{dir: logDir}
const route = "/v1/chat/completions"
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
var bodySeenByHandler string
r := gin.New()
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
t.Fatalf("failed to read body in handler: %v", err)
}
bodySeenByHandler = string(body)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
req.Host = "127.0.0.1:11434"
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if bodySeenByHandler != requestBody {
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
}
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
if err != nil {
t.Fatalf("failed to glob body logs: %v", err)
}
if len(bodyFiles) != 1 {
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
}
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
if err != nil {
t.Fatalf("failed to glob curl logs: %v", err)
}
if len(curlFiles) != 1 {
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
}
bodyData, err := os.ReadFile(bodyFiles[0])
if err != nil {
t.Fatalf("failed to read body log: %v", err)
}
if string(bodyData) != requestBody {
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
}
curlData, err := os.ReadFile(curlFiles[0])
if err != nil {
t.Fatalf("failed to read curl log: %v", err)
}
curlString := string(curlData)
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
}
bodyFileName := filepath.Base(bodyFiles[0])
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
}
}
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
requestLogger, err := newInferenceRequestLogger()
if err != nil {
t.Fatalf("expected no error creating request logger: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(requestLogger.dir)
})
if requestLogger == nil || requestLogger.dir == "" {
t.Fatalf("expected request logger directory to be set")
}
info, err := os.Stat(requestLogger.dir)
if err != nil {
t.Fatalf("expected directory to exist: %v", err)
}
if !info.IsDir() {
t.Fatalf("expected %q to be a directory", requestLogger.dir)
}
}
func TestSanitizeRouteForFilename(t *testing.T) {
tests := []struct {
route string
want string
}{
{route: "/api/generate", want: "api_generate"},
{route: "/v1/chat/completions", want: "v1_chat_completions"},
{route: "/v1/messages", want: "v1_messages"},
}
for _, tt := range tests {
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
}
}
}

View File

@@ -5,9 +5,13 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"os"
"os/exec"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -21,6 +25,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -195,25 +200,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if s.loadSafetensors(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -563,20 +557,101 @@ iGPUScan:
return false
}
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
func subproc(args, environ []string) (*exec.Cmd, int, error) {
exe, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
}
for range 3 {
// get a random port in the ephemeral range
port := rand.Intn(65535-49152) + 49152
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
cmd.Env = slices.Concat(os.Environ(), environ)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
continue
}
return cmd, port, nil
}
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
return s.loadImageGen(req)
}
args := []string{"--mlx-engine", "--model", req.model.ShortName}
environ := []string{}
cmd, port, err := subproc(args, environ)
if err != nil {
req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
llama: &mlxrunner.Client{
Cmd: cmd,
Port: port,
},
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
for range time.Tick(20 * time.Millisecond) {
if err := func() error {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
return runner.llama.Ping(ctx)
}(); err != nil {
continue
}
break
}
return true
}
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = imagegen.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := mlxrunner.NewServer(modelName, mode)
server, err := imagegen.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// For tokenizer that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}

tokenizer/tokenizer.go Normal file

@@ -0,0 +1,310 @@
package tokenizer
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/ollama/ollama/types/model"
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
func New(root *model.Root) (Tokenizer, error) {
f, err := root.Open("tokenizer.json")
if err != nil {
return nil, err
}
defer f.Close()
var tokenizer struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"`
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
return nil, err
}
special := make(map[int32]struct{})
for _, token := range tokenizer.AddedTokens {
tokenizer.Model.Vocab[token.Content] = token.ID
special[token.ID] = struct{}{}
}
vocab, err := specialTokens(root, tokenizer.Model.Vocab)
if err != nil {
return nil, err
}
vocab.Values = make([]string, len(tokenizer.Model.Vocab))
vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
for content, id := range tokenizer.Model.Vocab {
vocab.Values[id] = content
vocab.Scores[id] = float32(id)
vocab.Types[id] = TOKEN_TYPE_NORMAL
if _, ok := special[id]; ok {
vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
}
}
if tokenizer.Model.Merges != nil {
var pairs [][]string
if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
vocab.Merges = make([]string, len(pairs))
for i, pair := range pairs {
vocab.Merges[i] = pair[0] + " " + pair[1]
}
} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
return nil, err
}
}
vocab.valuesOnce.Do(func() {})
vocab.values = tokenizer.Model.Vocab
if tokenizer.Model.Type == "WordPiece" {
return NewWordPiece(vocab, true), nil
}
if tokenizer.Decoder != nil {
var decoder struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"string"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
return nil, err
}
if decoder.Type == "Sequence" {
for _, d := range decoder.Decoders {
if d.Type == "Replace" && d.Pattern.String == "▁" {
return NewSentencePiece(vocab), nil
}
}
}
}
var pretokenizers []string
if tokenizer.PreTokenizer != nil {
var pretokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string
} `json:"pattern"`
IndividualDigits bool `json:"individual_digits"`
}
}
if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
return nil, err
}
if pretokenizer.Type == "Sequence" {
for _, pretokenizer := range pretokenizer.Pretokenizers {
switch pretokenizer.Type {
case "Digits":
if pretokenizer.IndividualDigits {
pretokenizers = append(pretokenizers, `\d`)
} else {
pretokenizers = append(pretokenizers, `\d+`)
}
case "Punctuation":
pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
case "Split":
pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
case "WhitespaceSplit":
pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
}
}
}
}
return NewBytePairEncoding(vocab, pretokenizers...), nil
}
// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
var s []E
if err := json.Unmarshal(data, &s); err != nil {
var e E
if err := json.Unmarshal(data, &e); err != nil {
return err
}
s = []E{e}
}
*m = valueOrValues[E](s)
return nil
}
type specialTokenIDs struct {
BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}
// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return err
}
if content, ok := m["content"].(string); ok {
s = content
}
}
*t = stringOrContent(s)
return nil
}
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
var vocab Vocabulary
for _, c := range []struct {
name string
fn func(io.Reader) error
}{
{
name: "generation_config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
vocab.BOS = c.BOSTokenID
vocab.EOS = c.EOSTokenID
return nil
},
},
{
name: "config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 {
vocab.BOS = c.BOSTokenID
}
if len(vocab.EOS) == 0 {
vocab.EOS = c.EOSTokenID
}
return nil
},
},
{
name: "tokenizer_config.json",
fn: func(r io.Reader) error {
var c struct {
BOSToken stringOrContent `json:"bos_token"`
EOSToken stringOrContent `json:"eos_token"`
PADToken stringOrContent `json:"pad_token"`
AddBOSToken bool `json:"add_bos_token"`
AddEOSToken bool `json:"add_eos_token"`
}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 && c.BOSToken != "" {
if id, ok := values[string(c.BOSToken)]; ok {
vocab.BOS = []int32{id}
}
}
if len(vocab.EOS) == 0 && c.EOSToken != "" {
if id, ok := values[string(c.EOSToken)]; ok {
vocab.EOS = []int32{id}
}
}
vocab.AddBOS = c.AddBOSToken
vocab.AddEOS = c.AddEOSToken
return nil
},
},
{
name: "special_tokens_map.json",
fn: func(r io.Reader) error {
var c map[string]stringOrContent
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
if id, ok := values[string(bos)]; ok {
vocab.BOS = []int32{id}
}
}
if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
if id, ok := values[string(eos)]; ok {
vocab.EOS = []int32{id}
}
}
return nil
},
},
} {
if err := func() error {
f, err := root.Open(c.name)
if errors.Is(err, os.ErrNotExist) {
return nil
} else if err != nil {
return err
}
defer f.Close()
return c.fn(f)
}(); err != nil {
return nil, err
}
}
return &vocab, nil
}
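For orientation, a heavily trimmed sketch of the tokenizer.json fields New actually reads; the key names follow the Hugging Face tokenizers layout assumed here, and the vocabulary is a toy placeholder:

  {
    "added_tokens": [
      { "id": 4, "content": "</s>", "special": true }
    ],
    "pre_tokenizer": {
      "type": "Sequence",
      "pretokenizers": [
        { "type": "Digits", "individual_digits": true },
        { "type": "Split", "pattern": { "Regex": "\\p{L}+" } }
      ]
    },
    "model": {
      "type": "BPE",
      "vocab": { "<unk>": 0, "h": 1, "e": 2, "he": 3 },
      "merges": ["h e"]
    }
  }

WordPiece models are selected by model.type, SentencePiece is inferred from a Replace("▁") decoder, and everything else falls through to byte-pair encoding with the collected pretokenizer patterns.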

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package model
package tokenizer
import (
"slices"

types/model/file.go Normal file

@@ -0,0 +1,309 @@
package model
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"maps"
"mime"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
func root() (*os.Root, error) {
root, err := os.OpenRoot(envconfig.Models())
if err != nil {
return nil, err
}
for _, sub := range []string{"manifests", "blobs"} {
if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
if err := root.MkdirAll(sub, 0o750); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
}
return root, nil
}
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
f, err := r.Open(filepath.Join("manifests", n.Filepath()))
if err != nil {
return nil, err
}
defer f.Close()
var m manifest
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
blobs := make(map[string]*blob, len(m.Layers)+1)
blobs[NamePrefix] = m.Config
for _, layer := range m.Layers {
if layer.Name == "" && layer.MediaType != "" {
mediatype, _, err := mime.ParseMediaType(layer.MediaType)
if err != nil {
return nil, err
}
if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
layer.Name = NamePrefix + suffix
}
}
blobs[layer.Name] = layer
}
return &Root{
root: r,
name: n,
blobs: blobs,
flags: os.O_RDONLY,
}, nil
}
// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
return &Root{
root: r,
name: n,
blobs: make(map[string]*blob),
flags: os.O_RDWR,
}, nil
}
type blob struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Name string `json:"name,omitempty"`
Size int64 `json:"size"`
// tempfile is the temporary file where the blob data is written.
tempfile *os.File
// hash is the hash.Hash used to compute the blob digest.
hash hash.Hash
}
func (b *blob) Write(p []byte) (int, error) {
return io.MultiWriter(b.tempfile, b.hash).Write(p)
}
func (b *blob) Filepath() string {
return strings.ReplaceAll(b.Digest, ":", "-")
}
type manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *blob `json:"config"`
Layers []*blob `json:"layers"`
}
// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
root *os.Root
name Name
blobs map[string]*blob
flags int
}
const MediaTypePrefix = "application/vnd.ollama"
// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are identified by their media types:
//
// - name: NamePrefix + suffix
// - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
// - name: "./..image.model"
// - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."
// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
if b, ok := r.blobs[name]; ok {
r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
if err != nil {
return nil, err
}
return r, nil
}
return nil, fs.ErrNotExist
}
func (r Root) ReadFile(name string) ([]byte, error) {
f, err := r.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need to be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
if r.flags&os.O_RDWR != 0 {
w, err := os.CreateTemp(r.root.Name(), "")
if err != nil {
return nil, err
}
r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
return r.blobs[name], nil
}
return nil, fs.ErrInvalid
}
// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
if r.flags&os.O_RDWR != 0 {
for _, b := range r.blobs {
if b.tempfile != nil {
fi, err := b.tempfile.Stat()
if err != nil {
return err
}
if err := b.tempfile.Close(); err != nil {
return err
}
b.Size = fi.Size()
b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
if b.Name == NamePrefix {
b.MediaType = "application/vnd.docker.container.image.v1+json"
} else {
b.MediaType = MediaTypePrefix + suffix
}
b.Name = ""
}
rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
if err != nil {
return err
}
if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
return err
}
}
}
p := filepath.Join("manifests", r.name.Filepath())
if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
return err
}
} else if err != nil {
return err
}
f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
if err != nil {
return err
}
defer f.Close()
if err := json.NewEncoder(f).Encode(manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: r.blobs[NamePrefix],
Layers: func() []*blob {
blobs := make([]*blob, 0, len(r.blobs))
for name, b := range r.blobs {
if name != NamePrefix {
blobs = append(blobs, b)
}
}
return blobs
}(),
}); err != nil {
return err
}
}
return r.root.Close()
}
// Name returns the name of the file.
func (r Root) Name() Name {
return r.name
}
// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
return maps.Keys(r.blobs)
}
// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
if _, err := filepath.Match(pattern, ""); err != nil {
return nil, err
}
return func(yield func(string) bool) {
for name, blob := range r.blobs {
if matched, _ := filepath.Match(pattern, name); matched {
if !yield(blob.Filepath()) {
return
}
}
}
}, nil
}
func (r Root) JoinPath(parts ...string) string {
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}
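As a sketch, the manifest written by Close for a file with the config blob (created under the reserved name "./.") plus one ordinary blob named "model.safetensors" would look roughly like this; digests and sizes are illustrative:

  {
    "schemaVersion": 2,
    "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
    "config": {
      "digest": "sha256:…",
      "mediaType": "application/vnd.docker.container.image.v1+json",
      "size": 17
    },
    "layers": [
      {
        "digest": "sha256:…",
        "mediaType": "",
        "name": "model.safetensors",
        "size": 4096
      }
    ]
  }

Only blobs whose name starts with the reserved "./." prefix are rewritten to an Ollama media type; other blobs keep their caller-supplied name in the manifest, which is what lets readers address them by their original names.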

types/model/file_test.go Normal file

@@ -0,0 +1,90 @@
package model
import (
"io"
"strings"
"testing"
)
// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
for m, s := range models {
f, err := Create(m)
if err != nil {
t.Fatal(err)
}
for n, r := range s {
w, err := f.Create(n)
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(w, r); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
}
}
func TestOpen(t *testing.T) {
setup(t, map[Name]map[string]io.Reader{
ParseName("namespace/model"): {
"./.": strings.NewReader(`{"key":"value"}`),
},
ParseName("namespace/model:8b"): {
"./.": strings.NewReader(`{"foo":"bar"}`),
},
ParseName("another/model"): {
"./.": strings.NewReader(`{"another":"config"}`),
},
})
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"./."} {
r, err := f.Open(name)
if err != nil {
t.Fatal(err)
}
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
if err := r.Close(); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
t.Run("does not exist", func(t *testing.T) {
if _, err := Open(ParseName("namespace/unknown")); err == nil {
t.Error("expected error for unknown model")
}
})
t.Run("write", func(t *testing.T) {
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.Create("new-blob"); err == nil {
t.Error("expected error creating blob in read-only mode")
}
})
}

33
types/model/files.go Normal file
View File

@@ -0,0 +1,33 @@
package model
import (
"io/fs"
"iter"
"path/filepath"
)
func All() (iter.Seq[Name], error) {
r, err := root()
if err != nil {
return nil, err
}
manifests, err := r.OpenRoot("manifests")
if err != nil {
return nil, err
}
matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
if err != nil {
return nil, err
}
return func(yield func(Name) bool) {
for _, match := range matches {
name := ParseNameFromFilepath(filepath.ToSlash(match))
if !yield(name) {
return
}
}
}, nil
}
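All globs the manifests directory and yields one parsed Name per installed model. A hedged sketch of listing models with it, again assuming the types/model import path and a populated OLLAMA_MODELS directory:

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/types/model"
)

func main() {
	names, err := model.All()
	if err != nil {
		log.Fatal(err)
	}
	for n := range names {
		fmt.Println(n.String()) // fully qualified model names
	}
}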

View File

@@ -227,6 +227,17 @@ func (n Name) String() string {
return b.String()
}
// Set implements [flag.Value]. It parses the provided input as a name string
// and sets the receiver to the parsed value. If the parsed name is not valid,
// ErrUnqualifiedName is returned.
func (n *Name) Set(s string) error {
*n = ParseName(s)
if !n.IsValid() {
return ErrUnqualifiedName
}
return nil
}
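Because Name now satisfies flag.Value, a model name can be bound directly to a command-line flag. A minimal sketch under the same types/model import path assumption; the flag set, flag name, and example value are illustrative:

package main

import (
	"flag"
	"fmt"
	"log"

	"github.com/ollama/ollama/types/model"
)

func main() {
	var name model.Name
	fs := flag.NewFlagSet("example", flag.ContinueOnError)
	fs.Var(&name, "model", "fully qualified model name")
	if err := fs.Parse([]string{"-model", "registry.ollama.ai/library/gemma:latest"}); err != nil {
		log.Fatal(err) // Set returns ErrUnqualifiedName for an invalid name
	}
	fmt.Println(name.String())
}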
// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
var sb strings.Builder

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"context"
@@ -11,7 +11,7 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
@@ -28,8 +28,8 @@ var imageGenMu sync.Mutex
func (s *server) loadImageModel() error {
// Check memory requirements before loading
var requiredMemory uint64
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(manifest.TotalTensorSize())
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(modelManifest.TotalTensorSize())
}
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
@@ -38,7 +38,7 @@ func (s *server) loadImageModel() error {
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(s.modelName)
modelType := DetectModelType(s.modelName)
slog.Info("detected image model type", "type", modelType)
var model ImageModel
@@ -108,7 +108,7 @@ func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, r
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
imageData, err := EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"encoding/json"
@@ -12,8 +12,8 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -197,13 +197,13 @@ func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
manifest, err := imagegen.LoadManifest(s.modelName)
modelManifest, err := manifest.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
@@ -232,7 +232,7 @@ func (s *server) loadLLMModel() error {
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(manifest)
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"path/filepath"

View File

@@ -1,6 +1,6 @@
//go:build mlx
package imagegen
package manifest
import (
"fmt"
@@ -15,9 +15,9 @@ import (
type ManifestWeights struct {
manifest *ModelManifest
component string
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader from manifest storage.

View File

@@ -14,6 +14,8 @@ import (
"encoding/json"
"fmt"
"runtime"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// SupportedBackends lists the backends that support image generation.
@@ -41,8 +43,8 @@ func CheckPlatformSupport() error {
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
modelManifest, err := manifest.LoadManifest(modelName)
if err == nil && modelManifest.HasTensorLayers() {
return modelName
}
return ""
@@ -52,12 +54,12 @@ func ResolveModelName(modelName string) string {
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
data, err := modelManifest.ReadConfig("model_index.json")
if err != nil {
return ""
}

View File

@@ -12,7 +12,7 @@ import (
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -9,8 +9,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,11 +38,11 @@ type Config struct {
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
@@ -82,7 +82,7 @@ type MLAAttention struct {
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
@@ -634,7 +634,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
@@ -653,7 +653,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
tokData, err := modelManifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
@@ -664,12 +664,12 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -7,8 +7,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,10 +103,9 @@ type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -132,11 +131,11 @@ type Attention struct {
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
OutDim int32 // computed from Output
}
// Forward computes final output
@@ -401,12 +400,12 @@ type Transformer struct {
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Upsample conv
{
prev := x
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -8,8 +8,8 @@ import (
"fmt"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

203
x/imagegen/runner.go Normal file
View File

@@ -0,0 +1,203 @@
//go:build mlx
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
// Set up logging with appropriate level from environment
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
// Initialize MLX
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
// Detect model type from capabilities
mode := detectModelMode(*modelName)
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
// Create and start server
server, err := newServer(*modelName, *port, mode)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
// LLM-specific endpoints
if mode == ModeLLM {
mux.HandleFunc("/tokenize", server.tokenizeHandler)
mux.HandleFunc("/embedding", server.embeddingHandler)
}
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("mlx runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
return ModeImageGen
}
}
// Default to LLM mode for safetensors models without known image gen types
return ModeLLM
}
// server holds the model and handles HTTP requests.
type server struct {
mode ModelMode
modelName string
port int
// Image generation model (when mode == ModeImageGen)
imageModel ImageModel
// LLM model (when mode == ModeLLM)
llmModel *llmState
}
// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
s := &server{
mode: mode,
modelName: modelName,
port: port,
}
switch mode {
case ModeImageGen:
if err := s.loadImageModel(); err != nil {
return nil, fmt.Errorf("failed to load image model: %w", err)
}
case ModeLLM:
if err := s.loadLLMModel(); err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}
}
return s, nil
}
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := HealthResponse{Status: "ok"}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch s.mode {
case ModeImageGen:
s.handleImageCompletion(w, r, req)
case ModeLLM:
s.handleLLMCompletion(w, r, req)
}
}
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
var req struct {
Content string `json:"content"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
tok := s.llmModel.model.Tokenizer()
tokens := tok.Encode(req.Content, false)
// Convert int32 to int for JSON response
intTokens := make([]int, len(tokens))
for i, t := range tokens {
intTokens[i] = int(t)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}
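The runner streams newline-delimited JSON over /completion; the Server wrapper below drives this exchange in-process. A condensed, hedged sketch of talking to an already running runner directly, useful for following the protocol: the port, prompt, and sizes are illustrative, the Request fields mirror the ones built in Server.Completion, and the response tags are the ones decoded there.

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"

	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	creq := imagegen.Request{
		Prompt: "a lighthouse at dusk", // illustrative
		Width:  1024,
		Height: 1024,
		Steps:  9,
		Seed:   42,
	}
	body, err := json.Marshal(creq)
	if err != nil {
		log.Fatal(err)
	}
	// Port is illustrative; the runner must already be listening on it.
	resp, err := http.Post("http://127.0.0.1:43111/completion", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 1024*1024), 16*1024*1024) // responses may carry base64 images
	for sc.Scan() {
		var chunk struct {
			Content string `json:"content"`
			Image   string `json:"image"`
			Done    bool   `json:"done"`
		}
		if err := json.Unmarshal(sc.Bytes(), &chunk); err != nil {
			continue
		}
		fmt.Print(chunk.Content)
		if chunk.Done {
			break
		}
	}
}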

View File

@@ -1,6 +1,6 @@
//go:build !mlx
package mlxrunner
package imagegen
import "errors"

471
x/imagegen/server.go Normal file
View File

@@ -0,0 +1,471 @@
package imagegen
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
mode ModelMode
vramSize uint64
done chan error
client *http.Client
lastErr string // Last stderr line for error reporting
lastErrLock sync.Mutex
}
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the current executable path (we use the same binary with runner subcommand)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
mode: mode,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
s.done <- err
}()
// Wait for subprocess to be ready
if err := s.waitUntilRunning(); err != nil {
s.Close()
return nil, err
}
return s, nil
}
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := s.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-s.done:
// Include the most recent stderr line for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
}
}
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
// Extract raw image bytes from llm.ImageData slice
var images [][]byte
for _, img := range req.Images {
images = append(images, img.Data)
}
// Build request for subprocess
creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: int(req.Steps),
Seed: seed,
Images: images,
}
// Pass LLM options if present
if req.Options != nil {
creq.Options = &RequestOptions{
NumPredict: req.Options.NumPredict,
Temperature: float64(req.Options.Temperature),
TopP: float64(req.Options.TopP),
TopK: req.Options.TopK,
Stop: req.Options.Stop,
}
}
body, err := json.Marshal(creq)
if err != nil {
return err
}
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
// Parse subprocess response
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
// Log stop reason when generation completes
if raw.Done && raw.StopReason != "" {
slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
}
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
if cresp.Done {
return nil
}
}
// Scanner exited without receiving Done - connection was likely closed
scanErr := scanner.Err()
if scanErr != nil {
slog.Error("mlx scanner error", "error", scanErr)
} else {
slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
}
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
}
return scanErr
}
// Close terminates the subprocess.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
select {
case <-s.done:
case <-time.After(5 * time.Second):
s.cmd.Process.Kill()
}
s.cmd = nil
}
return nil
}
// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
}
// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// ContextLength returns the context length (not applicable for image generation).
func (s *Server) ContextLength() int {
return 0
}
// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embeddings not supported for MLX models")
}
// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
body, err := json.Marshal(map[string]string{"content": content})
if err != nil {
return nil, err
}
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
}
var result struct {
Tokens []int `json:"tokens"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
return result.Tokens, nil
}
// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenization not supported for MLX models")
}
// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
return s.cmd.Process.Pid
}
return -1
}
// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)
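Because Server satisfies llm.LlamaServer, the calling side can treat an MLX image model like any other runner. A hedged sketch of that flow; the model name and generation parameters are illustrative, and the model must already be installed as a safetensors image model:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	srv, err := imagegen.NewServer("namespace/z-image", imagegen.ModeImageGen)
	if err != nil {
		log.Fatal(err)
	}
	defer srv.Close()

	err = srv.Completion(context.Background(), llm.CompletionRequest{
		Prompt: "a lighthouse at dusk",
		Width:  1024,
		Height: 1024,
		Steps:  9,
	}, func(r llm.CompletionResponse) {
		if r.Done && r.Image != "" {
			// Image is a base64-encoded PNG per the runner's image path.
			fmt.Println("received base64 PNG of", len(r.Image), "bytes")
		}
	})
	if err != nil {
		log.Fatal(err)
	}
}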

View File

@@ -1,9 +1,9 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner
package imagegen
// Request is the request format for completion requests.
type Request struct {

View File

@@ -1,77 +0,0 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

View File

@@ -1,144 +0,0 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

View File

@@ -1,156 +0,0 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }

View File

@@ -1,110 +0,0 @@
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }

Some files were not shown because too many files have changed in this diff.