Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
c330ea33ed qwen3next: handle mixed recurrent batches
Allow mixed token-count batches by tracking per-seq indices

and falling back to per-seq recurrent processing when layouts

differ.

Add per-slot conv/delta state access with checkpoint capture,

relax attention layout handling, and reuse projections in mixed

batches to reduce overhead.
2026-02-05 11:50:00 -08:00
106 changed files with 7117 additions and 3024 deletions

View File

@@ -1,22 +0,0 @@
name: test-install
on:
pull_request:
paths:
- 'scripts/install.sh'
- '.github/workflows/test-install.yaml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Run install script
run: sh ./scripts/install.sh
env:
OLLAMA_NO_START: 1 # do not start app
- name: Verify ollama is available
run: ollama --version

View File

@@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string {
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
estimatedInputTokens: estimatedInputTokens,
toolCallsSent: make(map[string]bool),
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
@@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
if c.firstWrite {
c.firstWrite = false
// Use actual metrics if available, otherwise use estimate
c.inputTokens = r.Metrics.PromptEvalCount
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
c.inputTokens = c.estimatedInputTokens
}
events = append(events, StreamEvent{
Event: "message_start",
@@ -785,123 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
Model string `json:"model"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
}
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
return estimateTokens(CountTokensRequest{
Model: req.Model,
Messages: req.Messages,
System: req.System,
Tools: req.Tools,
Thinking: req.Thinking,
})
}
// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
InputTokens int `json:"input_tokens"`
}
// estimateTokens returns a rough estimate of tokens (len/4).
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
func estimateTokens(req CountTokensRequest) int {
var totalLen int
// Count system prompt
if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages {
// Count role (always present)
totalLen += len(msg.Role)
// Count content
contentLen := countAnyContent(msg.Content)
totalLen += contentLen
}
for _, tool := range req.Tools {
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
}
// Return len/4 as rough token estimate, minimum 1 if there's any content
tokens := totalLen / 4
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
tokens = 1
}
return tokens
}
func countAnyContent(content any) int {
if content == nil {
return 0
}
switch c := content.(type) {
case string:
return len(c)
case []any:
total := 0
for _, block := range c {
total += countContentBlock(block)
}
return total
default:
if data, err := json.Marshal(content); err == nil {
return len(data)
}
return 0
}
}
func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0
blockType, _ := blockMap["type"].(string)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
}
if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
}
if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

View File

@@ -321,6 +321,8 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -603,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
@@ -676,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -729,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
@@ -776,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
@@ -840,6 +842,10 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
@@ -893,9 +899,11 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -929,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
@@ -961,105 +969,3 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
}
})
}
func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"},
},
}
tokens := estimateTokens(req)
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
if tokens < 1 {
t.Errorf("expected at least 1 token, got %d", tokens)
}
// Sanity check: shouldn't be wildly off
if tokens > 10 {
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
}
}
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
tokens := estimateTokens(req)
// System prompt adds to count
if tokens < 5 {
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
}
}
func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"},
},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get the current weather for a location",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
},
},
}
tokens := estimateTokens(req)
// Tools add significant content
if tokens < 10 {
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
}
}
func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this carefully...",
},
map[string]any{
"type": "text",
"text": "Here is my response.",
},
},
},
},
}
tokens := estimateTokens(req)
// Thinking content should be counted
if tokens < 10 {
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
}
}
func TestEstimateTokens_EmptyContent(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{},
}
tokens := estimateTokens(req)
if tokens != 0 {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}

View File

@@ -466,25 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -1763,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return err
return fmt.Errorf("ollama server not responding - %w", err)
}
}
return nil

View File

@@ -1,23 +1,18 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner and AliasConfigurer for Claude Code integration
// Claude implements Runner for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
@@ -65,104 +60,3 @@ func (c *Claude) Run(model string, args []string) error {
)
return cmd.Run()
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := selectPrompt("Select model:", items)
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -13,8 +13,7 @@ import (
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Models []string `json:"models"`
}
type config struct {
@@ -134,16 +133,8 @@ func saveIntegration(appName string, models []string) error {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
}
return save(cfg)
@@ -163,29 +154,6 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -1,677 +0,0 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}

View File

@@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})

View File

@@ -39,15 +39,6 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
@@ -138,11 +129,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
} else {
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
if err != nil {
return nil, err
}
@@ -170,123 +157,73 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{
Name: m.Name,
Remote: m.RemoteModel != "",
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) == 0 {
return nil
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return fmt.Errorf("%s requires sign in", modelList)
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string, args []string) error {
@@ -294,33 +231,10 @@ func runIntegration(name, modelName string, args []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
}
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
} else {
delete(aliases, "fast")
}
if err := ac.SetAliases(ctx, aliases); err != nil {
return err
}
return saveAliases(name, aliases)
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
@@ -388,87 +302,9 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0], passArgs)
}
}
@@ -482,8 +318,6 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
@@ -546,9 +380,8 @@ Examples:
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
Name string
Remote bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
@@ -585,7 +418,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
continue
}
items = append(items, rec)
if strings.HasSuffix(rec.Name, ":cloud") {
if isCloudModel(rec.Name) {
cloudModels[rec.Name] = true
}
}
@@ -645,16 +478,8 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return items, preChecked, existingModels, cloudModels
}
// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud")
}
func pullModel(ctx context.Context, client *api.Client, model string) error {

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"fmt"
"slices"
"strings"
@@ -298,15 +297,24 @@ func TestParseArgs(t *testing.T) {
}
func TestIsCloudModel(t *testing.T) {
// isCloudModel now only uses Show API, so nil client always returns false
t.Run("nil client returns false", func(t *testing.T) {
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
for _, model := range models {
if isCloudModel(context.Background(), nil, model) {
t.Errorf("isCloudModel(%q) with nil client should return false", model)
tests := []struct {
name string
want bool
}{
{"glm-4.7:cloud", true},
{"kimi-k2.5:cloud", true},
{"glm-4.7-flash", false},
{"glm-4.7-flash:latest", false},
{"cloud-model", false},
{"model:cloudish", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isCloudModel(tt.name); got != tt.want {
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
}
}
})
})
}
}
func names(items []selectItem) []string {
@@ -501,41 +509,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}

View File

@@ -17,6 +17,8 @@ type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -11,52 +10,12 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -154,8 +113,6 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -165,29 +122,12 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
entry := map[string]any{
models[model] = map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models

View File

@@ -2,7 +2,6 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -496,165 +495,6 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()

View File

@@ -17,7 +17,6 @@ const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiClearDown = "\033[J"
)

View File

@@ -96,14 +96,6 @@ func TestSelectState(t *testing.T) {
}
})
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState([]selectItem{})
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
@@ -582,19 +574,8 @@ func TestRenderSelect(t *testing.T) {
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "no matches") {
t.Errorf("expected 'no matches' message, got: %s", output)
}
})
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
s := newSelectState([]selectItem{})
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message for empty list with no filter")
t.Error("expected 'no matches' message")
}
})

View File

@@ -10,21 +10,19 @@ import (
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return errNotRunning
return err
}
link, err := os.Readlink(exe)
if err != nil {
return errNotRunning
return err
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errNotRunning
return errors.New("could not find ollama app")
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -117,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -143,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var tok tokenizer.Tokenizer
var textProcessor model.TextProcessor
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
tok, err = model.NewTextProcessor(modelPath)
textProcessor, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -156,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if tok == nil {
if textProcessor == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -212,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if tok == nil {
if textProcessor == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -262,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
tok != nil,
textProcessor != nil,
modelPath,
gpuLibs,
status,
@@ -311,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1775,7 +1774,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.tokenizer.Encode(content, false)
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1810,7 +1809,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.tokenizer.Decode(toks)
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}

View File

@@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
converter: anthropic.NewStreamConverter(messageID, req.Model),
}
if req.Stream {

272
model/bytepairencoding.go Normal file
View File

@@ -0,0 +1,272 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -23,7 +23,6 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -134,7 +133,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -151,7 +150,7 @@ func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
return nil, err
}
tp, ok := m.(tokenizer.Tokenizer)
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -130,7 +129,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &tokenizer.Vocabulary{
vocab := &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -154,17 +153,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, true)
processor = model.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -223,7 +222,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -278,8 +277,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -135,8 +134,8 @@ func init() {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -28,7 +27,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,8 +43,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
*TextModel
poolingType pooling.Type
@@ -32,8 +31,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +54,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -71,19 +70,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var t tokenizer.Tokenizer
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
t = tokenizer.NewBytePairEncoding(&vocabulary)
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
t = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
}
m := Model{
Tokenizer: t,
TextProcessor: processor,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,12 +6,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.SentencePiece
*TextModel
}
@@ -24,8 +23,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -199,7 +198,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -237,8 +236,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -38,8 +37,8 @@ func New(c fs.Config) (model.Model, error) {
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,12 +12,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -197,8 +196,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,7 +10,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -60,7 +59,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -79,7 +78,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := tokenizer.Vocabulary{
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -105,8 +104,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -26,7 +25,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -42,8 +41,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
var processor model.TextProcessor
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -81,16 +80,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = tokenizer.NewSentencePiece(&vocabulary)
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v"`
@@ -34,8 +33,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -29,12 +28,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ tokenizer.Tokenizer = (*Model)(nil)
var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*VisionModel `gguf:"v"`
*TextModel
@@ -33,8 +32,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,12 +11,11 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -179,6 +178,29 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -197,29 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
TextProcessor: processor,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -34,7 +33,7 @@ type Options struct {
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -45,24 +44,28 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
Layers: make([]Layer, c.Uint("block_count")),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -93,7 +92,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -140,8 +139,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
@@ -28,8 +27,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,12 +7,11 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
*Model
poolingType pooling.Type
@@ -35,8 +34,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -160,7 +159,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -219,8 +218,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
// Fallback: treat as flat batch if layout doesn't match.
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
batchSize = batchSize * nSeqs
} else {
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
// Layout mismatch; proceed with flat batch.
}
}
}

View File

@@ -64,6 +64,8 @@ type HybridCache struct {
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// token indices per sequence in batch order
curSeqTokenIdxs [][]int32
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
}
if len(c.curSeqs) == 0 {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
return nil
}
if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
} else {
c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
}
for i := range c.curSeqTokenIdxs {
c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
}
seqIndex := make(map[int]int, len(c.curSeqs))
for i, s := range c.curSeqs {
seqIndex[s] = i
}
for i, s := range batch.Sequences {
c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
uniform := true
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
uniform = false
break
}
}
c.curSeqTokens = want
if uniform {
c.curSeqTokens = want
} else {
// Mixed batch: recurrent layers will process sequences independently.
c.curSeqTokens = 0
}
// When reserving memory for estimation, use fake slot assignments
if reserve {
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1].
func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil
}
// updateConvStateForSlot writes a new conv state for a single slot.
func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1].
func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
cur := buf.Rows(ctx, slotIdx)
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil
}
// updateDeltaStateForSlot writes a new delta state for a single slot.
func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, 1)
srcF32 := src.Cast(ctx, ml.DTypeF32)
slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32)
}
func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
return
}
pos := c.curCheckpointPos[seqIndex]
if pos < 0 {
return
}
slot := c.curSlots[seqIndex]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
return
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
ctx.Forward(src.Copy(ctx, dst))
}
// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

View File

@@ -48,6 +48,13 @@ type GatedDeltaNet struct {
Layer int
}
type stateAccessors struct {
convState func() (ml.Tensor, error)
updateConv func(ml.Tensor)
deltaState func() (ml.Tensor, error)
updateDelta func(ml.Tensor)
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
@@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks {
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
@@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
layer := gdn.Layer
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.ConvState(ctx, layer)
},
updateConv: func(newState ml.Tensor) {
cache.UpdateConvState(ctx, layer, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.DeltaState(ctx, layer, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.UpdateDeltaState(ctx, layer, newState)
},
}
return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access)
}
if cache == nil {
return nil, ErrUnsupportedBatchLayout
}
return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
}
func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
if hiddenStates.Dim(2) > 0 {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))
}
if len(cache.curSeqs) == 0 {
return hiddenStates, nil
}
// Ensure any shared slots are detached once for this forward pass.
cache.ensureWritableOnce(ctx)
layer := gdn.Layer
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Precompute projections for the full batch and slice per sequence.
mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates)
zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
out := hiddenStates
for seqIndex := range cache.curSeqs {
idxs := cache.curSeqTokenIdxs[seqIndex]
if len(idxs) == 0 {
continue
}
idxTensor := ctx.Input().FromInts(idxs, len(idxs))
mixedBA := mixedBAFull.Rows(ctx, idxTensor)
qkvMixed := qkvMixedFull.Rows(ctx, idxTensor)
z := zFull.Rows(ctx, idxTensor)
slot := cache.curSlots[seqIndex]
access := stateAccessors{
convState: func() (ml.Tensor, error) {
return cache.convStateForSlot(ctx, layer, slot)
},
updateConv: func(newState ml.Tensor) {
cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState)
},
deltaState: func() (ml.Tensor, error) {
return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads)
},
updateDelta: func(newState ml.Tensor) {
cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState)
},
}
seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access)
if err != nil {
return nil, err
}
out = out.SetRows(ctx, seqOut, idxTensor)
}
return out, nil
}
func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) {
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access)
}
func (gdn *GatedDeltaNet) forwardProjected(
ctx ml.Context,
nSeqTokens, nSeqs int,
mixedBA, qkvMixed, z ml.Tensor,
opts *Options,
access stateAccessors,
) (ml.Tensor, error) {
layer := gdn.Layer
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
@@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
convStates, err := access.convState()
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
@@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
access.updateConv(lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
@@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
state, err := access.deltaState()
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
@@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
var (
attnOut ml.Tensor
newState ml.Tensor
)
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts)
}
access.updateDelta(newState)
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
@@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
@@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState
}
// deltaNetChunked implements chunked computation for prefill.
@@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
) (ml.Tensor, ml.Tensor) {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
@@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
@@ -208,7 +207,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
// Model is the main Qwen3-Next model
type Model struct {
model.Base
tokenizer.Tokenizer
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -354,8 +353,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,12 +10,11 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
model.TextProcessor
*TextModel
*VisionModel `gguf:"v"`
@@ -173,8 +172,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ Tokenizer = (*SentencePiece)(nil)
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizer that use byte tokens like "<0xEA>"
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type Tokenizer interface {
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"testing"

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements Tokenizer.
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements Tokenizer.
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements Tokenizer.
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements Tokenizer.
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ Tokenizer = (*WordPiece)(nil)
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"slices"

View File

@@ -37,7 +37,6 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -211,9 +210,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := tok.Decode([]int32{int32(tokenID)})
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -243,7 +242,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -765,7 +764,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -774,14 +773,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -879,7 +878,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
pieces := make([]string, len(tok.Vocabulary().Values))
for i := range tok.Vocabulary().Values {
pieces[i], _ = tok.Decode([]int32{int32(i)})
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/model"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) tokenizer.Tokenizer {
func modelHelper(t testing.TB) model.BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) tokenizer.Tokenizer {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
return model.NewBytePairEncoding(
&model.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,

View File

@@ -1,5 +1,5 @@
#!/bin/sh
# This script installs Ollama on Linux and macOS.
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
@@ -27,7 +27,8 @@ require() {
echo $MISSING
}
OS="$(uname -s)"
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -35,65 +36,6 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
###########################################
# macOS
###########################################
if [ "$OS" = "Darwin" ]; then
NEEDS=$(require curl unzip)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
if [ -n "${OLLAMA_VERSION:-}" ]; then
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/download/${OLLAMA_VERSION}/Ollama-darwin.zip"
else
DOWNLOAD_URL="https://github.com/ollama/ollama/releases/latest/download/Ollama-darwin.zip"
fi
if pgrep -x Ollama >/dev/null 2>&1; then
status "Stopping running Ollama instance..."
pkill -x Ollama 2>/dev/null || true
sleep 2
fi
if [ -d "/Applications/Ollama.app" ]; then
status "Removing existing Ollama installation..."
rm -rf "/Applications/Ollama.app"
fi
status "Downloading Ollama for macOS..."
curl --fail --show-error --location --progress-bar \
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
status "Installing Ollama to /Applications..."
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
mv "$TEMP_DIR/Ollama.app" "/Applications/"
status "Adding 'ollama' command to PATH (may require password)..."
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
if [ -z "${OLLAMA_NO_START:-}" ]; then
status "Starting Ollama..."
open -a Ollama --args hidden
fi
status "Install complete. You can now run 'ollama'."
exit 0
fi
###########################################
# Linux
###########################################
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
IS_WSL2=false
KERN=$(uname -r)

View File

@@ -1,422 +0,0 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const (
serverConfigFilename = "server.json"
serverConfigVersion = 1
)
var errAliasCycle = errors.New("alias cycle detected")
type aliasEntry struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
type serverConfig struct {
Version int `json:"version"`
Aliases []aliasEntry `json:"aliases"`
}
type store struct {
mu sync.RWMutex
path string
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
prefixEntries []aliasEntry // prefix matches, sorted longest-first
}
func createStore(path string) (*store, error) {
store := &store{
path: path,
entries: make(map[string]aliasEntry),
}
if err := store.load(); err != nil {
return nil, err
}
return store, nil
}
func (s *store) load() error {
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var cfg serverConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
return fmt.Errorf("unsupported router config version %d", cfg.Version)
}
for _, entry := range cfg.Aliases {
targetName := model.ParseName(entry.Target)
if !targetName.IsValid() {
slog.Warn("invalid alias target in router config", "target", entry.Target)
continue
}
canonicalTarget := displayAliasName(targetName)
if entry.PrefixMatching {
// Prefix aliases don't need to be valid model names
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
slog.Warn("empty prefix alias in router config")
continue
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: alias,
Target: canonicalTarget,
PrefixMatching: true,
})
} else {
aliasName := model.ParseName(entry.Alias)
if !aliasName.IsValid() {
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
continue
}
canonicalAlias := displayAliasName(aliasName)
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
Alias: canonicalAlias,
Target: canonicalTarget,
}
}
}
// Sort prefix entries by alias length descending (longest prefix wins)
s.sortPrefixEntriesLocked()
return nil
}
func (s *store) saveLocked() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
// Combine exact and prefix entries
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
cfg := serverConfig{
Version: serverConfigVersion,
Aliases: entries,
}
f, err := os.CreateTemp(dir, "router-*.json")
if err != nil {
return err
}
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(cfg); err != nil {
_ = f.Close()
_ = os.Remove(f.Name())
return err
}
if err := f.Close(); err != nil {
_ = os.Remove(f.Name())
return err
}
if err := os.Chmod(f.Name(), 0o644); err != nil {
_ = os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), s.path)
}
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
// If a local model exists, do not allow alias shadowing (highest priority).
exists, err := localModelExists(name)
if err != nil {
return name, false, err
}
if exists {
return name, false, nil
}
key := normalizeAliasKey(name)
s.mu.RLock()
entry, exactMatch := s.entries[key]
var prefixMatch *aliasEntry
if !exactMatch {
// Try prefix matching - prefixEntries is sorted longest-first
nameStr := strings.ToLower(displayAliasName(name))
for i := range s.prefixEntries {
prefix := strings.ToLower(s.prefixEntries[i].Alias)
if strings.HasPrefix(nameStr, prefix) {
prefixMatch = &s.prefixEntries[i]
break // First match is longest due to sorting
}
}
}
s.mu.RUnlock()
if !exactMatch && prefixMatch == nil {
return name, false, nil
}
var current string
var visited map[string]struct{}
if exactMatch {
visited = map[string]struct{}{key: {}}
current = entry.Target
} else {
// For prefix match, use the target as-is
visited = map[string]struct{}{}
current = prefixMatch.Target
}
targetKey := normalizeAliasKeyString(current)
for {
targetName := model.ParseName(current)
if !targetName.IsValid() {
return name, false, fmt.Errorf("alias target %q is invalid", current)
}
if _, seen := visited[targetKey]; seen {
return name, false, errAliasCycle
}
visited[targetKey] = struct{}{}
s.mu.RLock()
next, ok := s.entries[targetKey]
s.mu.RUnlock()
if !ok {
return targetName, true, nil
}
current = next.Target
targetKey = normalizeAliasKeyString(current)
}
}
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
targetKey := normalizeAliasKey(target)
s.mu.Lock()
defer s.mu.Unlock()
if prefixMatching {
// For prefix aliases, we skip cycle detection since prefix matching
// works differently and the target is a specific model
aliasStr := displayAliasName(alias)
// Remove any existing prefix entry with the same alias
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
break
}
}
s.prefixEntries = append(s.prefixEntries, aliasEntry{
Alias: aliasStr,
Target: displayAliasName(target),
PrefixMatching: true,
})
s.sortPrefixEntriesLocked()
return s.saveLocked()
}
aliasKey := normalizeAliasKey(alias)
if aliasKey == targetKey {
return fmt.Errorf("alias cannot point to itself")
}
visited := map[string]struct{}{aliasKey: {}}
currentKey := targetKey
for {
if _, seen := visited[currentKey]; seen {
return errAliasCycle
}
visited[currentKey] = struct{}{}
next, ok := s.entries[currentKey]
if !ok {
break
}
currentKey = normalizeAliasKeyString(next.Target)
}
s.entries[aliasKey] = aliasEntry{
Alias: displayAliasName(alias),
Target: displayAliasName(target),
}
return s.saveLocked()
}
func (s *store) Delete(alias model.Name) (bool, error) {
aliasKey := normalizeAliasKey(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try exact match first
if _, ok := s.entries[aliasKey]; ok {
delete(s.entries, aliasKey)
return true, s.saveLocked()
}
// Try prefix entries
aliasStr := displayAliasName(alias)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, aliasStr) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
return false, nil
}
// DeleteByString deletes an alias by its raw string value, useful for prefix
// aliases that may not be valid model names.
func (s *store) DeleteByString(alias string) (bool, error) {
alias = strings.TrimSpace(alias)
aliasLower := strings.ToLower(alias)
s.mu.Lock()
defer s.mu.Unlock()
// Try prefix entries first (since this is mainly for prefix aliases)
for i, e := range s.prefixEntries {
if strings.EqualFold(e.Alias, alias) {
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
return true, s.saveLocked()
}
}
// Also check exact entries by normalized key
if _, ok := s.entries[aliasLower]; ok {
delete(s.entries, aliasLower)
return true, s.saveLocked()
}
return false, nil
}
func (s *store) List() []aliasEntry {
s.mu.RLock()
defer s.mu.RUnlock()
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
for _, entry := range s.entries {
entries = append(entries, entry)
}
entries = append(entries, s.prefixEntries...)
sort.Slice(entries, func(i, j int) bool {
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
})
return entries
}
func normalizeAliasKey(name model.Name) string {
return strings.ToLower(displayAliasName(name))
}
func (s *store) sortPrefixEntriesLocked() {
sort.Slice(s.prefixEntries, func(i, j int) bool {
// Sort by length descending (longest prefix first)
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
})
}
func normalizeAliasKeyString(value string) string {
n := model.ParseName(value)
if !n.IsValid() {
return strings.ToLower(strings.TrimSpace(value))
}
return normalizeAliasKey(n)
}
func displayAliasName(n model.Name) string {
display := n.DisplayShortest()
if strings.EqualFold(n.Tag, "latest") {
if idx := strings.LastIndex(display, ":"); idx != -1 {
return display[:idx]
}
}
return display
}
func localModelExists(name model.Name) (bool, error) {
manifests, err := manifest.Manifests(true)
if err != nil {
return false, err
}
needle := name.String()
for existing := range manifests {
if strings.EqualFold(existing.String(), needle) {
return true, nil
}
}
return false, nil
}
func serverConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".ollama", serverConfigFilename)
}
return filepath.Join(home, ".ollama", serverConfigFilename)
}
func (s *Server) aliasStore() (*store, error) {
s.aliasesOnce.Do(func() {
s.aliases, s.aliasesErr = createStore(serverConfigPath())
})
return s.aliases, s.aliasesErr
}
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
store, err := s.aliasStore()
if err != nil {
return name, false, err
}
if store == nil {
return name, false, nil
}
return store.ResolveName(name)
}

View File

@@ -22,7 +22,6 @@ import (
"os/signal"
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -82,9 +81,6 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
aliasesOnce sync.Once
aliases *store
aliasesErr error
}
func init() {
@@ -195,16 +191,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name, err = getExistingName(name)
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -1591,9 +1580,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@@ -1964,20 +1950,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
name = resolvedName
name, err = getExistingName(name)
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(name.String())
m, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):

View File

@@ -1,159 +0,0 @@
package server
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/types/model"
)
type aliasListResponse struct {
Aliases []aliasEntry `json:"aliases"`
}
type aliasDeleteRequest struct {
Alias string `json:"alias"`
}
func (s *Server) ListAliasesHandler(c *gin.Context) {
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var aliases []aliasEntry
if store != nil {
aliases = store.List()
}
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
}
func (s *Server) CreateAliasHandler(c *gin.Context) {
var req aliasEntry
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
req.Target = strings.TrimSpace(req.Target)
if req.Alias == "" || req.Target == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
return
}
// Target must always be a valid model name
targetName := model.ParseName(req.Target)
if !targetName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
return
}
var aliasName model.Name
if req.PrefixMatching {
// For prefix aliases, we still parse the alias to normalize it,
// but we allow any non-empty string since prefix patterns may not be valid model names
aliasName = model.ParseName(req.Alias)
// Even if not valid as a model name, we accept it for prefix matching
} else {
aliasName = model.ParseName(req.Alias)
if !aliasName.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
return
}
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
return
}
exists, err := localModelExists(aliasName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if exists {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
return
}
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, errAliasCycle) {
status = http.StatusBadRequest
}
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
return
}
resp := aliasEntry{
Alias: displayAliasName(aliasName),
Target: displayAliasName(targetName),
PrefixMatching: req.PrefixMatching,
}
if req.PrefixMatching && !aliasName.IsValid() {
// For prefix aliases that aren't valid model names, use the raw alias
resp.Alias = req.Alias
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) DeleteAliasHandler(c *gin.Context) {
var req aliasDeleteRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Alias = strings.TrimSpace(req.Alias)
if req.Alias == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
return
}
store, err := s.aliasStore()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
aliasName := model.ParseName(req.Alias)
var deleted bool
if aliasName.IsValid() {
deleted, err = store.Delete(aliasName)
} else {
// For invalid model names (like prefix aliases), try deleting by raw string
deleted, err = store.DeleteByString(req.Alias)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !deleted {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}

View File

@@ -1,426 +0,0 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestAliasShadowingRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "shadowed-model",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestAliasResolvesForChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
var remoteModel string
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req api.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
remoteModel = req.Model
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: req.Model,
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "target-model",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "alias-model",
Messages: []api.Message{{Role: "user", Content: "hi"}},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Model != "alias-model" {
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
}
if remoteModel != "test" {
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
}
}
func TestPrefixAliasBasicMatching(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a prefix alias: "myprefix-" -> "targetmodel"
targetName := model.ParseName("targetmodel")
// Set a prefix alias (using "myprefix-" as the pattern)
store.mu.Lock()
store.prefixEntries = append(store.prefixEntries, aliasEntry{
Alias: "myprefix-",
Target: "targetmodel",
PrefixMatching: true,
})
store.mu.Unlock()
// Test that "myprefix-foo" resolves to "targetmodel"
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != targetName.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
}
// Test that "otherprefix-foo" does not resolve
otherName := model.ParseName("otherprefix-foo")
_, wasResolved, err = store.ResolveName(otherName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatal("expected name not to be resolved")
}
// Test that exact alias takes precedence
exactAlias := model.ParseName("myprefix-exact")
exactTarget := model.ParseName("exacttarget")
if err := store.Set(exactAlias, exactTarget, false); err != nil {
t.Fatal(err)
}
resolved, wasResolved, err = store.ResolveName(exactAlias)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasLongestMatchWins(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add two prefix aliases with overlapping patterns
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
}
store.sortPrefixEntriesLocked()
store.mu.Unlock()
// "abc-def-ghi" should match the longer prefix "abc-def-"
testName := model.ParseName("abc-def-ghi")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedLongTarget := model.ParseName("long-target")
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
}
// "abc-xyz" should match the shorter prefix "abc-"
testName2 := model.ParseName("abc-xyz")
resolved, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
expectedShortTarget := model.ParseName("short-target")
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasChain(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Create a chain: prefix "test-" -> "intermediate" -> "final"
intermediate := model.ParseName("intermediate")
final := model.ParseName("final")
// Add prefix alias
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
}
store.mu.Unlock()
// Add exact alias for the intermediate step
if err := store.Set(intermediate, final, false); err != nil {
t.Fatal(err)
}
// "test-foo" should resolve through the chain to "final"
testName := model.ParseName("test-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved")
}
if resolved.DisplayShortest() != final.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
}
}
func TestPrefixAliasCRUD(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a prefix alias via API
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "llama2",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var createResp aliasEntry
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
t.Fatal(err)
}
if !createResp.PrefixMatching {
t.Fatal("expected prefix_matching to be true in response")
}
// List aliases and verify the prefix alias is included
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var listResp aliasListResponse
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
found := false
for _, a := range listResp.Aliases {
if a.PrefixMatching && a.Target == "llama2" {
found = true
break
}
}
if !found {
t.Fatal("expected to find prefix alias in list")
}
// Delete the prefix alias
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify it's deleted
w = createRequest(t, s.ListAliasesHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
for _, a := range listResp.Aliases {
if a.PrefixMatching {
t.Fatal("expected prefix alias to be deleted")
}
}
}
func TestPrefixAliasCaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
store, err := createStore(filepath.Join(tmpDir, "server.json"))
if err != nil {
t.Fatal(err)
}
// Add a prefix alias with mixed case
store.mu.Lock()
store.prefixEntries = []aliasEntry{
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
}
store.mu.Unlock()
// Test that matching is case-insensitive
testName := model.ParseName("myprefix-foo")
resolved, wasResolved, err := store.ResolveName(testName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (case-insensitive)")
}
expectedTarget := model.ParseName("targetmodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
// Test uppercase request
testName2 := model.ParseName("MYPREFIX-BAR")
_, wasResolved, err = store.ResolveName(testName2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected name to be resolved (uppercase)")
}
}
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("HOME", t.TempDir())
s := Server{}
// Create a local model that would match a prefix alias
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "myprefix-localmodel",
RemoteHost: "example.com",
From: "test",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Create a prefix alias that would match the local model name
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
Alias: "myprefix-",
Target: "someothermodel",
PrefixMatching: true,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
store, err := s.aliasStore()
if err != nil {
t.Fatal(err)
}
localModelName := model.ParseName("myprefix-localmodel")
resolved, wasResolved, err := store.ResolveName(localModelName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wasResolved {
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
}
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
}
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
nonLocalName := model.ParseName("myprefix-nonexistent")
resolved, wasResolved, err = store.ResolveName(nonLocalName)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !wasResolved {
t.Fatal("expected non-local model to resolve via prefix alias")
}
expectedTarget := model.ParseName("someothermodel")
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
}
}

View File

@@ -417,9 +417,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
numParallel = 1
}
// Some architectures are not safe with num_parallel > 1.
// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}

77
x/kvcache/cache.go Normal file
View File

@@ -0,0 +1,77 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

144
x/kvcache/causal.go Normal file
View File

@@ -0,0 +1,144 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewCausalCache() *Causal {
return &Causal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := c.offset; i < len(batch.Positions); i++ {
locsPut[i-c.offset] = int32(i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX Causal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *Causal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}

156
x/kvcache/encoder.go Normal file
View File

@@ -0,0 +1,156 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }

110
x/kvcache/wrapper.go Normal file
View File

@@ -0,0 +1,110 @@
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }

433
x/ml/backend.go Normal file
View File

@@ -0,0 +1,433 @@
package ml
import (
"fmt"
"log/slog"
"os"
"github.com/ollama/ollama/fs"
)
type Backend interface {
// Close frees all memory associated with this backend
// Close()
// Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
// BackendMemory() BackendMemory
Config() fs.Config
Get(name string) Tensor
NewContext() Context
// NewContextSize(size int) Context
// Enumerate the devices available for inference via this backend
// BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// AllocMemory causes the backend to allocate memory for the model. If
// false, this is only being used for discovering the required amount of
// memory and cannot load the model for running.
AllocMemory bool
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
// GPULayers is the set of layers to offload to GPUs
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
be := os.Getenv("OLLAMA_BACKEND")
if be == "" {
be = "mlx"
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
}
slog.Info("Loading new engine", "backend", be)
if backend, ok := backends[be]; ok {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
FromInts(s []int32, shape ...int) Tensor
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
// SetBatchSize(int)
Compute(...Tensor)
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available for
// for future inference.
// Reserve()
// MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// Load a tensor from "filename" safetensors file, and compare with the input tensor
// Returns error if the shape is inconsistent, or similarity measures are below 99%
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
type RoPEOptions struct {
Base *float32
Freqs Tensor
}
func WithRoPEBase(base float32) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Base = &base
}
}
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Freqs = freqs
}
}
type Tensor interface {
ToString() string
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
// TakeAxes(ctx Context, axes int, indicies ...int) Tensor
Dim(n int) int
Stride(n int) int
Shape() []int
DType() DType
// Cast(ctx Context, dtype DType) Tensor
// Bytes() []byte
Floats() []float32
Ints() []int32
// FromBytes([]byte)
// FromFloats([]float32)
// FromInts([]int32)
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
// Mul(ctx Context, t2 Tensor) Tensor
// Div(ctx Context, t2 Tensor) Tensor
Max(ctx Context, axes []int, keepDims bool) Tensor
Min(ctx Context, axes []int, keepDims bool) Tensor
Matmul(ctx Context, a2 Tensor) Tensor
// Mulmat(ctx Context, t2 Tensor) Tensor
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
// MulmatID(ctx Context, t2, ids Tensor) Tensor
// AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
// SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
// Sin(ctx Context) Tensor
// Cos(ctx Context) Tensor
// Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
// QuickGELU(ctx Context, up ...Tensor) Tensor
// SILU(ctx Context, up ...Tensor) Tensor
// RELU(ctx Context, up ...Tensor) Tensor
// Sigmoid(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
Transpose(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, allowColMajor bool) Tensor
// Pad(ctx Context, shape ...int) Tensor
// Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
// Repeat(ctx Context, dim, n int) Tensor
// Concat(ctx Context, t2 Tensor, dim int) Tensor
// Rows(ctx Context, t2 Tensor) Tensor
// TODO these probably aren't actually needed - false starts on trying to wire up cache
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
// Duplicate(ctx Context) Tensor
// Slice(ctx Context, dim, low, high, step int) Tensor
// Chunk(ctx Context, dim int, size int) []Tensor
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
// TopK(ctx Context, k int) Tensor
// Argsort(ctx Context) Tensor
// Mean(ctx Context) Tensor
// Variance(ctx Context) Tensor
// Stddev(ctx Context) Tensor
// Sqr(ctx Context) Tensor
// Sqrt(ctx Context) Tensor
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
// type number interface {
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// ~float32 | ~float64 |
// ~complex64 | ~complex128
// }
// func mul[T number](s ...T) T {
// p := T(1)
// for _, v := range s {
// p *= v
// }
// return p
// }
// type DumpOptions func(*dumpOptions)
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Precision = n
// }
// }
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Threshold = n
// }
// }
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.EdgeItems = n
// }
// }
// type dumpOptions struct {
// Precision, Threshold, EdgeItems int
// }
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// for _, optsFunc := range optsFuncs {
// optsFunc(&opts)
// }
// if mul(t.Shape()...) <= opts.Threshold {
// opts.EdgeItems = math.MaxInt
// }
// switch t.DType() {
// case DTypeFloat32:
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeFloat16: // TODO other types...
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// f32 = t.Copy(ctx, f32)
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeInt32:
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// return strconv.FormatInt(int64(i), 10)
// })
// default:
// return "<unsupported>"
// }
// }
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// if t.Bytes() == nil {
// ctx.Compute(t)
// }
// s := make(S, mul(t.Shape()...))
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// panic(err)
// }
// shape := t.Shape()
// slices.Reverse(shape)
// var sb strings.Builder
// var f func([]int, int)
// f = func(dims []int, stride int) {
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// sb.WriteString("[")
// defer func() { sb.WriteString("]") }()
// for i := 0; i < dims[0]; i++ {
// if i >= items && i < dims[0]-items {
// sb.WriteString("..., ")
// // skip to next printable element
// skip := dims[0] - 2*items
// if len(dims) > 1 {
// stride += mul(append(dims[1:], skip)...)
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// }
// i += skip - 1
// } else if len(dims) > 1 {
// f(dims[1:], stride)
// stride += mul(dims[1:]...)
// if i < dims[0]-1 {
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// }
// } else {
// text := fn(s[stride+i])
// if len(text) > 0 && text[0] != '-' {
// sb.WriteString(" ")
// }
// sb.WriteString(text)
// if i < dims[0]-1 {
// sb.WriteString(", ")
// }
// }
// }
// }
// f(shape, 0)
// return sb.String()
// }
type DType int
const (
DTypeBool DType = iota
DTypeUint8
DTypeUint16
DTypeUint32
DTypeUint64
DTypeInt8
DTypeInt16
DTypeInt32
DTypeInt64
DTypeFloat16
DTypeFloat32
DTypeFloat64
DTypeBfloat16
DTypeComplex64
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

3
x/ml/backend/backend.go Normal file
View File

@@ -0,0 +1,3 @@
package backend
// _ "github.com/ollama/ollama/x/ml/backend/mlx"

View File

@@ -0,0 +1,61 @@
include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
)
endif()
endfunction()
# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
else()
# On Linux, disable Metal backend
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG})
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)

1278
x/ml/backend/mlx/mlx.go Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,92 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
static const char* LIB_NAMES[] = {
"libmlxc.dylib",
"@loader_path/../build/lib/ollama/libmlxc.dylib",
"@executable_path/../build/lib/ollama/libmlxc.dylib",
"build/lib/ollama/libmlxc.dylib",
"../build/lib/ollama/libmlxc.dylib",
NULL
};
#else
static const char* LIB_NAMES[] = {
"libmlxc.so",
"$ORIGIN/../build/lib/ollama/libmlxc.so",
"build/lib/ollama/libmlxc.so",
"../build/lib/ollama/libmlxc.so",
NULL
};
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
// Try each possible library path
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
if (mlx_handle != NULL) {
mlx_initialized = 1;
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", LIB_NAMES[i]);
return 0;
}
}
// Failed to load library
const char* err = LIB_ERROR();
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. %s",
err ? err : "Unknown error");
return -1;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}

View File

@@ -0,0 +1,26 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H

View File

@@ -0,0 +1,314 @@
//go:build mlx
package mlx
import (
"log/slog"
"os"
"reflect"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
_ "github.com/ollama/ollama/x/model/models/gemma3"
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestLoadModel(t *testing.T) {
dir := "/Users/daniel/Models/gemma-3-4b-it/"
b := &Backend{}
err := b.LoadSafeTensors(dir)
if err != nil {
t.Fatalf("load failed: %s", err)
}
}
func TestFromInts(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []int32{1, 2, 3, 4, 5, 6}
a := c.FromInts(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
}
func TestFromFloats(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []float32{1, 2, 3, 4, 5, 6}
a := c.FromFloats(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
res := a.Floats()
if !reflect.DeepEqual(res, data) {
t.Fatalf("incorrect results: %v", res)
}
}
func TestAdd(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
t3 := t1.Add(c, t2)
c.Compute(t3, exp)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp.Floats()) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestReshapeTranspose(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
c.Compute(t1)
t1f := t1.Floats()
exp := []float32{
0, 4, 8,
1, 5, 9,
2, 6, 10,
3, 7, 11,
12, 16, 20,
13, 17, 21,
14, 18, 22,
15, 19, 23,
}
if !reflect.DeepEqual(t1f, exp) {
t.Fatalf("incorrect results: %v", t1f)
}
}
func prod(vals ...int) int {
r := 1
for _, v := range vals {
r *= v
}
return r
}
func TestMatmul(t *testing.T) {
// TODO create scenarios...
b := &Backend{}
c := b.NewContext()
defer c.Close()
s1 := []int{1, 3, 2, 4}
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
s2 := []int{4, 2}
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
t3 := t1.Matmul(c, t2)
exp := []float32{
28, 34,
76, 98,
124, 162,
172, 226,
220, 290,
268, 354,
}
c.Compute(t3)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestRows(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
outputs := c.Zeros(ml.DTypeInt32, 1)
t2 := t1.TakeAxes(c, outputs, 1)
c.Forward(t1, t2).Compute(t1, t2)
t.Log(t1.ToString())
t.Log(t2.ToString())
f := t2.Floats()
t.Logf("Result: %v", f)
}
func TestCaching(t *testing.T) {
// Validate the caching algorithm
b := &Backend{}
c := b.NewContext()
defer c.Close()
batchSize := 3
headDim := 4
numKVHeads := 2
// Make cache twice the size of one test batch
cells := batchSize * 2
cellSize := numKVHeads * headDim
shape := []int{1, numKVHeads, batchSize, headDim}
stop := float32(1)
for _, x := range shape {
stop *= float32(x)
}
// Create the cache
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
// Input tensor
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
// Reshape to copy into the cache
/*
From MLX python/src/indexing.cpp mlx_scatter_args_array
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
*/
numRows := 3
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
// Simulate cells 1,3,5 are available
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
cache.Scatter(c, indicies, up, axis)
c.Forward(cache)
// Cache should contain the data now
t.Log("Cache after put\n" + cache.ToString())
// Retrieve cache content and verify it matches
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
t1f := t1.Floats()
outf := out.Floats()
if !reflect.DeepEqual(t1f, outf) {
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
}
}
func TestGemma3(t *testing.T) {
// Why is the sky blue
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
limit := 50
// TODO generalize this
dir := "/Users/daniel/Models/gemma-3-4b-it/"
m, err := model.New(dir, ml.BackendParams{})
if err != nil {
t.Fatalf("unable to load model: %s", err)
}
b := m.Backend()
ctx := b.NewContext()
defer ctx.Close()
batch := input.Batch{
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
Offset: 0,
}
for i := range len(inputs) {
batch.Positions[i] = int32(i)
}
offset := len(inputs)
cache := m.Config().Cache
if cache != nil {
numSlots := 1
batchSize := 512
numCtx := 4096
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
err := cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
}
opts := api.DefaultOptions()
var grammar *sample.GrammarSampler
sampler := sample.NewSampler(
opts.Temperature,
opts.TopK,
opts.TopP,
opts.MinP,
opts.Seed,
grammar,
)
t.Log("Starting Forward pass loop")
pendingResponses := []string{}
for {
out, err := m.Forward(ctx, batch)
if err != nil {
t.Fatalf("failed forward pass: %s", err)
}
ctx.Forward(out)
outputs := out.Floats()
t.Logf("finished forward pass! length:%d", len(outputs))
// sample a token
logits := outputs
token, err := sampler.Sample(logits)
if err != nil {
t.Fatalf("unable to sample token: %s", err)
}
t.Logf("Sampled token: %v", token)
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
t.Log("hit EOS")
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{token})
if err != nil {
t.Fatalf("unable to decode token: %s", err)
}
pendingResponses = append(pendingResponses, piece)
sequence := strings.Join(pendingResponses, "")
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
t.Logf("hit stop token: %v", stop)
break
}
t.Logf("RESULTS: %s", sequence)
batch = input.Batch{
Inputs: ctx.FromInts([]int32{token}, 1, 1),
Positions: make([]int32, 1),
Sequences: make([]int, 1),
Outputs: ctx.FromInts([]int32{0}, 1),
Offset: offset,
}
offset++
batch.Positions[0] = 0
err = cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
if offset > limit {
break
}
}
}

335
x/ml/backend/mlx/quant.go Normal file
View File

@@ -0,0 +1,335 @@
//go:build mlx
package mlx
/*
#include <stdio.h>
#include <string.h>
#include "mlx/c/array.h"
#include "mlx/c/ops.h"
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
void unpack_32_4(uint8_t* data, int8_t* dst) {
memset(dst, 0, 16);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
// Drived from ggml-quants.c
#define QK_K 256
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
uint16_t d; // super-block scale
} block_q6_K;
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
const int64_t nb = k / QK_K;
block_q6_K *x = (block_q6_K *)vx;
float16_t* y = (float16_t *)vy;
for (int i = 0; i < nb; i++) {
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
struct {
uint16_t d; // super-block scale for quantized scales
uint16_t dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
uint16_t dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
block_q4_K *x = (block_q4_K *)vx;
float16_t* y = (float16_t *)vy;
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
float16_t min = 0.0;
memcpy(&min, &x[i].dmin, sizeof(d));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float16_t d1 = d * sc; const float16_t m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float16_t d2 = d * sc; const float16_t m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/x448/float16"
)
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
shape := append([]C.int{}, final_shape...)
var weights_per_byte C.int
if dtype == 2 || dtype == 3 {
weights_per_byte = 2
} else if dtype == 8 {
weights_per_byte = 1
} else {
return r, fmt.Errorf("unsupported tensor type %d", dtype)
}
weights_per_block := C.int(32)
if shape[len(shape)-1]%weights_per_block != 0 {
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
}
weights_shape := append([]C.int{}, shape...)
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
for i := range weights_shape {
w_nbytes *= weights_shape[i]
}
w_data := make([]byte, w_nbytes)
cbytes := C.CBytes(w_data)
defer C.free(cbytes)
weights := C.mlx_array_new_data(
cbytes,
&weights_shape[0],
C.int(len(weights_shape)),
C.MLX_UINT32,
)
// For scales and bias
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
for i := range shape {
sb_nbytes *= shape[i]
}
s_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(s_data)
defer C.free(cbytes)
scales := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
b_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(b_data)
defer C.free(cbytes)
biases := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
var bits C.int
switch dtype {
case 2:
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 3:
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 8:
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 8
}
groupSize := C.mlx_optional_int{value: 32, has_value: true}
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
C.mlx_dequantize(
&r,
weights,
scales,
biases,
groupSize,
bitsOpt,
nil, // TODO mode
dtypeOpt,
stream,
)
C.mlx_array_free(weights)
C.mlx_array_free(scales)
C.mlx_array_free(biases)
return r, nil
}
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
size := 1
for _, d := range shape {
size *= int(d)
}
fdata := make([]float16.Float16, size)
switch dtype {
case 14:
C.dequant_row_q6_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
case 12:
C.dequant_row_q4_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
default:
return r, fmt.Errorf("unsupported K quant")
}
r = C.mlx_array_new_data(
unsafe.Pointer(&fdata[0]),
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
return r, nil
}

643
x/ml/device.go Normal file
View File

@@ -0,0 +1,643 @@
package ml
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"hash/maphash"
"io"
"log/slog"
"math"
"net/http"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/logutil"
)
// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
DeviceID
// Layers is a set of layer indicies to load
Layers []int
}
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
if len(g.Layers) == 0 {
return math.MaxInt
}
first := g.Layers[0]
for i := 1; i < len(g.Layers); i++ {
if g.Layers[i] < first {
first = g.Layers[i]
}
}
return first
}
func (g GPULayers) String() string {
if len(g.Layers) == 0 {
return ""
}
slices.Sort(g.Layers)
contiguous := true
base := g.Layers[0]
for i := range g.Layers {
if g.Layers[i] != base+i {
contiguous = false
break
}
}
if contiguous {
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
} else {
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
}
}
// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers
func (l GPULayersList) Len() int { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
li := l[i].FirstLayer()
lj := l[j].FirstLayer()
return li < lj
}
func (l GPULayersList) String() string {
if l.Sum() > 0 {
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
} else {
return fmt.Sprintf("%v", []GPULayers(l))
}
}
// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
var sum int
for _, g := range l {
sum += len(g.Layers)
}
return sum
}
var h maphash.Hash
// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
h.Reset()
for _, g := range l {
if len(g.Layers) > 0 {
h.WriteString(g.ID + g.Library)
for _, l := range g.Layers {
binary.Write(&h, binary.NativeEndian, int64(l))
}
}
}
return h.Sum64()
}
// ErrNoMem is returned when panicing due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
BackendMemory
}
func (e ErrNoMem) Error() string {
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
// Minimal unique device identification
type DeviceID struct {
// ID is an identifier for the device for matching with system
// management libraries. The ID is only unique for other devices
// using the same Library.
// This ID represents a "post filtered" view of the enumerated devices
// if the ID is numeric
ID string `json:"id"`
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
Library string `json:"backend,omitempty"`
}
// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string
// Weights is the per-layer memory needed for the model weights.
Weights []uint64
// Cache is the per-layer memory needed for the KV cache.
Cache []uint64
// Graph is the size of the compute graph. It is not per-layer.
Graph uint64
}
func sumMemory(mem []uint64) uint64 {
var sum uint64
for _, m := range mem {
sum += m
}
return sum
}
// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}
func memoryPresent(mem []uint64) bool {
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
InputWeights uint64
// CPU model components are located in system memory. This does not
// include unified memory allocated through the GPU.
CPU DeviceMemory
// GPU model components are located on one or more GPUs.
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
var total uint64
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := sumMemory(m.CPU.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := gpu.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.CPU.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
if total > 0 {
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
}
}
type DeviceInfo struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string `json:"name"`
// Description is the longer user-friendly identification of the device
Description string `json:"description"`
// FilterID is populated with the unfiltered device ID if a numeric ID is used
// so the device can be included.
FilterID string `json:"filter_id,omitempty"`
// Integrated is set true for integrated GPUs, false for Discrete GPUs
Integrated bool `json:"integration,omitempty"`
// PCIID is the bus, device and domain ID of the device for deduplication
// when discovered by multiple backends
PCIID string `json:"pci_id,omitempty"`
// TotalMemory is the total amount of memory the device can use for loading models
TotalMemory uint64 `json:"total_memory"`
// FreeMemory is the amount of memory currently available on the device for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// ComputeMajor is the major version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMajor int
// ComputeMinor is the minor version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMinor int
// Driver Information
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
// Where backends were loaded from
LibraryPath []string
}
type SystemInfo struct {
// ThreadCount is the optimal number of threads to use for inference
ThreadCount int `json:"threads,omitempty"`
// TotalMemory is the total amount of system memory
TotalMemory uint64 `json:"total_memory,omitempty"`
// FreeMemory is the amount of memory currently available on the system for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// FreeSwap is the amount of system swap space reported as available
FreeSwap uint64 `json:"free_swap,omitempty"`
}
func (d DeviceInfo) Compute() string {
// AMD gfx is encoded into the major minor in hex form
if strings.EqualFold(d.Library, "ROCm") {
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
}
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
func (d DeviceInfo) Driver() string {
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
if d.Library == "Metal" {
return 512 * format.MebiByte
}
return 457 * format.MebiByte
}
// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
if a[i].Integrated && !a[j].Integrated {
return true
} else if !a[i].Integrated && a[j].Integrated {
return false
}
return a[i].FreeMemory < a[j].FreeMemory
}
// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
scores := []bool{}
for _, info := range l {
found := false
requested := info.Integrated
for i, score := range scores {
if score == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
scores = append(scores, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir)
}
}
}
return gpuLibs
}
type DeviceComparison int
const (
UniqueDevice DeviceComparison = iota
SameBackendDevice // The device is the same, and the library/backend is the same
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
}
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
}
aLibSplit := strings.SplitN(aLib, "_", 2)
bLibSplit := strings.SplitN(bLib, "_", 2)
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
return false
}
if aLibSplit[0] != bLibSplit[0] {
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
return false
}
if aLibSplit[1] == bLibSplit[1] {
return false
}
cmp := []string{aLibSplit[1], bLibSplit[1]}
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
return cmp[0] == bLibSplit[1]
}
// For each GPU, check if it does NOT support flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false
}
}
return true
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
if len(l) == 0 {
return nil
}
env := map[string]string{}
for _, d := range l {
d.updateVisibleDevicesEnv(env, mustFilter)
}
return env
}
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}
// PreferredLibrary returns true if this library is preferred over the other input
// library
// Used to filter out Vulkan in favor of CUDA or ROCm
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
// TODO in the future if we find Vulkan is better than ROCm on some devices
// that implementation can live here.
if d.Library == "CUDA" || d.Library == "ROCm" {
return true
}
return false
}
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
var envVar string
switch d.Library {
case "ROCm":
// ROCm must be filtered as it can crash the runner on unsupported devices
envVar = "ROCR_VISIBLE_DEVICES"
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
}
case "CUDA":
if !mustFilter {
// By default we try to avoid filtering CUDA devices because ROCm also
// looks at the CUDA env var, and gets confused in mixed vendor environments.
return
}
envVar = "CUDA_VISIBLE_DEVICES"
default:
// Vulkan is not filtered via env var, but via scheduling decisions
return
}
v, existing := env[envVar]
if existing {
v = v + ","
}
if d.FilterID != "" {
v = v + d.FilterID
} else {
v = v + d.ID
}
env[envVar] = v
}
type BaseRunner interface {
// GetPort returns the localhost port number the runner is running on
GetPort() int
// HasExited indicates if the runner is no longer running. This can be used during
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
HasExited() bool
}
type RunnerDiscovery interface {
BaseRunner
// GetDeviceInfos will perform a query of the underlying device libraries
// for device identification and free VRAM information
// During bootstrap scenarios, this routine may take seconds to complete
GetDeviceInfos(ctx context.Context) []DeviceInfo
}
type FilteredRunnerDiscovery interface {
RunnerDiscovery
// GetActiveDeviceIDs returns the filtered set of devices actively in
// use by this runner for running models. If the runner is a bootstrap runner, no devices
// will be active yet so no device IDs are returned.
// This routine will not query the underlying device and will return immediately
GetActiveDeviceIDs() []DeviceID
}
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
var moreDevices []DeviceInfo
port := runner.GetPort()
tick := time.Tick(10 * time.Millisecond)
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Warn("failed to read response", "error", err)
continue
}
if resp.StatusCode != 200 {
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
return nil, fmt.Errorf("runner error: %s", string(body))
}
if err := json.Unmarshal(body, &moreDevices); err != nil {
slog.Warn("unmarshal encode response", "error", err)
continue
}
return moreDevices, nil
}
}
}

103
x/ml/nn/attention.go Normal file
View File

@@ -0,0 +1,103 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
// panic("after cache get") //
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
// var mask ml.Tensor
if cache != nil {
key, value, _ = cache.Get(ctx)
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
// panic("after cache get") //
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if cache != nil {
// TODO what to do with vmla?
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
} else {
panic("else case not supported")
// TODO transpose shapes are wrong
// key = key.Transpose(ctx, 0, 2, 1, 3)
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
// kq := query.Matmul(ctx, key)
// kq = kq.Scale(ctx, scale)
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
// kq = kq.Softmax(ctx)
// kqv := kq.Matmul(ctx, value)
// if vmla != nil {
// kqv = kqv.Matmul(ctx, vmla)
// }
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
}
}

30
x/ml/nn/convolution.go Normal file
View File

@@ -0,0 +1,30 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
if m.Bias != nil {
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
}
return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}

11
x/ml/nn/embedding.go Normal file
View File

@@ -0,0 +1,11 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Embedding struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
return m.Weight.TakeAxes(ctx, hiddenState, 0)
}

32
x/ml/nn/linear.go Normal file
View File

@@ -0,0 +1,32 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Linear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
panic("not yet ported")
// t = m.Weight.MulmatID(ctx, t, indices)
// if m.Bias != nil {
// t = t.AddID(ctx, m.Bias, indices)
// }
// return t
}

29
x/ml/nn/normalization.go Normal file
View File

@@ -0,0 +1,29 @@
package nn
import (
"github.com/ollama/ollama/x/ml"
)
type LayerNorm struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
// slog.Info("RMSNorm", "eps", eps)
// fmt.Fprintln(os.Stderr, t.ToString())
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
// TODO this is probably model specific, not generalized...
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
return t.RMSNorm(ctx, w, eps)
}

View File

@@ -0,0 +1,41 @@
package pooling
import (
"github.com/ollama/ollama/x/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
// case TypeMean:
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
// case TypeCLS:
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
// case TypeLast:
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}
}

72
x/ml/nn/rope/rope.go Normal file
View File

@@ -0,0 +1,72 @@
package rope
import "github.com/ollama/ollama/x/ml"
// Options contains optional parameters for RoPE function
type Options struct {
Type int
Factors ml.Tensor
// YaRN options
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// MRoPE options
MRoPE struct {
Sections []int
}
}
// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5
opts.MRoPE.Sections = sections
}
}

56
x/ml/path.go Normal file
View File

@@ -0,0 +1,56 @@
package ml
import (
"os"
"path/filepath"
"runtime"
)
// LibPath is a path to lookup dynamic libraries
// in development it's usually 'build/lib/ollama'
// in distribution builds it's 'lib/ollama' on Windows
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()

View File

@@ -1,4 +1,4 @@
package tokenizer
package model
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ Tokenizer = (*BytePairEncoding)(nil)
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
if len(pretokenizer) == 0 {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizer {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}

View File

@@ -0,0 +1,322 @@
package model
import (
"bufio"
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
t.Fatal(err)
}
types := make([]int32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
types[id] = 1
}
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
tokens = append(tokens, token) //nolint:makezero
types = append(types, 3) //nolint:makezero
vocab[token] = int32(len(vocab))
}
}
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
merges := make([]string, 0, 50000)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "#") {
merges = append(merges, scanner.Text())
}
}
return NewBytePairEncoding(
&Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
func TestLlama(t *testing.T) {
tokenizer := llama(t)
t.Run("simple", func(t *testing.T) {
t.Parallel()
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
s, err := tokenizer.Decode([]int32{15339, 1917})
if err != nil {
t.Fatal(err)
}
if s != "hello world" {
t.Errorf("got %q, want hello world", s)
}
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
t.Run("simple repeated", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
strings.Repeat("0", 1): {15},
strings.Repeat("0", 2): {410},
strings.Repeat("0", 3): {931},
strings.Repeat("0", 4): {931, 15},
strings.Repeat("0", 5): {931, 410},
strings.Repeat("0", 6): {931, 931},
strings.Repeat("0", 7): {931, 931, 15},
strings.Repeat("0", 8): {931, 931, 410},
strings.Repeat("0", 9): {931, 931, 931},
strings.Repeat("0", 10): {931, 931, 931, 15},
strings.Repeat("0", 11): {931, 931, 931, 410},
strings.Repeat("0", 12): {931, 931, 931, 931},
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
}
}
})
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
})
t.Run("special", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("split", func(t *testing.T) {
t.Parallel()
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
"Hello World": {"Hello", " ", " World"},
"Hello\nWorld": {"Hello", "\n", "World"},
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
t.Parallel()
for b := 0x00; b <= 0xFF; b++ {
input := string(rune(b))
ids, err := tokenizer.Encode(input, false)
if err != nil {
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
continue
}
if b == 0x00 {
if len(decoded) != 0 {
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
}
continue
}
if decoded != input {
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
}
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
if err != nil {
b.Fatal(err)
}
for i := range 8 {
n := min(int(math.Pow10(i)), len(bts))
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})
}
}
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}

76
x/model/input/input.go Normal file
View File

@@ -0,0 +1,76 @@
package input
import "github.com/ollama/ollama/x/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may chose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal []Multimodal
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// TODO maybe not the optimal way to handle this
// Offset of final tensor in the final batch
Offset int
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}

333
x/model/model.go Normal file
View File

@@ -0,0 +1,333 @@
package model
import (
"errors"
"fmt"
_ "image/jpeg"
_ "image/png"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
_ "github.com/ollama/ollama/x/ml/backend"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
type Base struct {
b ml.Backend
config
}
type config struct {
Cache kvcache.Cache
}
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
}
func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
models[name] = f
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Struct {
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
if !vv.CanSet() {
continue
}
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
var names []string
if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
if len(names) == 0 {
// current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
}
}
}
return fullNames
}
names := fn(tagsCopy, "", "")
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
}
}
}
if !canNil(tt) || !vv.IsNil() {
allNil = false
}
}
if allNil {
return reflect.Zero(t)
}
}
return v
}
func setPointer(base Base, v reflect.Value, tags []Tag) {
vv := v
if v.Kind() == reflect.Interface {
if v.IsNil() {
return
}
vv = vv.Elem()
}
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
if f := populateFields(base, vv, tags...); f.CanAddr() {
v.Set(f.Addr())
}
}
type Tag struct {
name,
// prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
}
func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
// elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
}
}
}
return
}
func canNil(t reflect.Type) bool {
return t.Kind() == reflect.Chan ||
t.Kind() == reflect.Func ||
t.Kind() == reflect.Interface ||
t.Kind() == reflect.Map ||
t.Kind() == reflect.Pointer ||
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}
ctx.Forward(t)
return t, nil
}

View File

@@ -0,0 +1,58 @@
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type embedModel struct {
model.Base
model.SentencePiece
*TextModel
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
return m, nil
}

View File

@@ -0,0 +1,157 @@
//go:build mlx
package gemma3
import (
"bytes"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type Model struct {
model.Base
model.SentencePiece
*VisionModel `gguf:"vision_tower.vision_model"`
*TextModel `gguf:"language_model.model"`
*MultiModalProjector `gguf:"multi_modal_projector"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
// slog.Info("XXX Config", "c", c)
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewCausalCache()
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>""
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -0,0 +1,211 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model/input"
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
Layers []TextLayer `gguf:"layers"`
OutputNorm *nn.RMSNorm `gguf:"norm"`
Output *nn.Linear `gguf:"embed_tokens"`
*TextConfig
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
// const (
// cacheTypeSWA = iota
// cacheTypeCausal
// )
func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
ropeScale: 1, // 1 - default is 1, implied in python code
// vocabSize: vocabSize, // 262144
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"q_proj"`
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
Key *nn.Linear `gguf:"k_proj"`
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
Value *nn.Linear `gguf:"v_proj"`
Output *nn.Linear `gguf:"o_proj"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
B := hiddenState.Dim(0)
L := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
traditional := false
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
// TODO - this is wrong somehow so commenting out
// if opts.largeModelScaling {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
// } else {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
// }
scaleFactor := math.Pow(256, -0.5)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// ropeBase := m.TextConfig.ropeLocalBase
// if (layer+1)%gemmaGlobalCacheCount == 0 {
// ropeBase = m.TextConfig.ropeGlobalBase
// }
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
panic("not yet implemented")
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"up_proj"`
Down *nn.Linear `gguf:"down_proj"`
Gate *nn.Linear `gguf:"gate_proj"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *TextSelfAttention `gguf:"self_attn"`
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
MLP *TextMLP `gguf:"mlp"`
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
residual = residual.TakeAxes(ctx, outputs, 1)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
// var except []int
// for _, image := range batch.Multimodal {
// visionOutputs := image.Multimodal[0].Tensor
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
// []int{image.Index * hiddenState.Stride(1)}, 0)))
// for i := range visionOutputs.Dim(1) {
// except = append(except, image.Index+i)
// }
// }
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
if cache != nil {
// cacheType := cacheTypeSWA
// if (i+1)%gemmaGlobalCacheCount == 0 {
// cacheType = cacheTypeCausal
// }
cache.SetLayer(i)
// TODO this needs to come back
// wc := cache.(*kvcache.WrapperCache)
// wc.SetLayerType(cacheType)
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
// }
}
var offset int
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
offset = batch.Offset
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}

View File

@@ -0,0 +1,121 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.out_proj"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -0,0 +1,60 @@
//go:build mlx
package gemma3
import (
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

3
x/model/models/models.go Normal file
View File

@@ -0,0 +1,3 @@
package models
// _ "github.com/ollama/ollama/x/model/models/gemma3"

249
x/model/sentencepiece.go Normal file
View File

@@ -0,0 +1,249 @@
package model
import (
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/ollama/ollama/logutil"
)
const spmWhitespaceSep = "▁"
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

Some files were not shown because too many files have changed in this diff Show More