Compare commits

...

48 Commits

Author SHA1 Message Date
Michael Yang
03db70c464 nil keys 2026-02-05 17:22:22 -08:00
Michael Yang
4e70a31da1 glm4.7 2026-02-05 17:22:22 -08:00
Michael Yang
f696b6c4d0 save 2026-02-05 17:22:22 -08:00
Michael Yang
9b952bbea4 cleanup afterloadfunc 2026-02-05 17:22:22 -08:00
Michael Yang
ce262df633 fix build duplicate symbols 2026-02-05 17:22:22 -08:00
Michael Yang
e0a4f0aa14 mlxrunner 2026-02-05 17:22:22 -08:00
Michael Yang
9cb6b8ac6d draft: model manifest file interface
this change makes it easier to address blobs by their original names
without rebuilding the filesystem structure
2026-02-05 17:22:22 -08:00
Michael Yang
87e01abd59 resolve circular dependency 2026-02-05 17:21:41 -08:00
Michael Yang
26448f1e7d mv x/mlxrunner x/imagegen 2026-02-05 17:21:41 -08:00
Michael Yang
693101589a simplify runner selection 2026-02-05 17:21:41 -08:00
Michael Yang
b8dff7e342 clean up unused directories 2026-02-05 17:21:41 -08:00
Michael Yang
b350656b23 move tokenizer to separate package 2026-02-05 17:21:37 -08:00
Parth Sareen
8a4b77f9da cmd: set context limits for cloud models in opencode (#14107) 2026-02-05 16:36:46 -08:00
Parth Sareen
5f53fe7884 cmd: ollama launch improvements (#14099) 2026-02-05 15:08:17 -08:00
Bruce MacDonald
7ab4ca0e7f scripts: add macOS support to install.sh (#14060)
Allow installing Ollama on MacOS directly from the command line. This is in line with other CLI tools and results in a more streamlined experience when the user is looking to use the CLI specifically.
2026-02-05 14:59:01 -08:00
Jeffrey Morgan
e36f389e82 scheduler: default parallel=1 for qwen3next/lfm (#14103) 2026-02-05 12:48:25 -08:00
Jesse Gross
c61023f554 ollamarunner: Fix off by one error with numPredict
When numPredict is set, the user will receive one less token
than the requested limit. In addition, the stats will incorrectly
show the number of tokens returned as the limit. In cases where
numPredict is not set, the number of tokens is reported correctly.

This occurs because numPredict is checked when setting up the next
batch but hitting the limit will terminate the current batch as well.
Instead, is is better to check the limit as we actually predict them.
2026-02-04 17:14:24 -08:00
Jeffrey Morgan
d25535c3f3 qwen3next: avoid inplace sigmoid for shared gate (#14077) 2026-02-04 15:50:02 -08:00
Bruce MacDonald
c323161f24 cmd: helpful error message for remote models (#14057)
When trying to use cloud model with OLLAMA_HOST="ollama.com" while not signed in a helpful error message is displayed when the user is not signed in telling them they must sign in to use cloud models. This should be the same experience for models which specify a remote instance.
2026-02-04 14:55:11 -08:00
Jeffrey Morgan
255579aaa7 qwen3next: fix issue in delta net (#14075)
gDiffExp was being broadcast across the wrong axis when multiplying with k. This fix reshapes gDiffExp to [1, chunkSize, nChunks, ...]
2026-02-04 13:40:38 -08:00
Jeffrey Morgan
f7102ba826 runner: discard compute results if sequence replaced mid-batch (#14072)
If a sequence is replaced in s.seqs while a batch is computing, the old logits can be decoded into the new sequence. This change rechecks the sequence pointer after compute and skips decoding for replaced entries, preventing stale results from being applied.
2026-02-04 13:19:48 -08:00
Jeffrey Morgan
cefabd79a8 Revert "cmd: claude launch improvements (#14064)" (#14071)
This reverts commit ee25219edd.
2026-02-04 09:10:37 -08:00
Jeffrey Morgan
df70249520 server: optimize chatPrompt to reduce tokenization calls (#14040)
Change the truncation algorithm to start with all messages and remove
from the front until it fits, rather than adding messages one at a time
from the back. This reduces tokenization calls from O(n) to O(1) in the
common case where all messages fit in context.
2026-02-04 01:21:31 -08:00
Jeffrey Morgan
77eb2ca619 model: add qwen3-next architecture (#14051) 2026-02-03 23:27:21 -08:00
Parth Sareen
ee25219edd cmd: claude launch improvements (#14064) 2026-02-03 19:33:58 -08:00
Jeffrey Morgan
b1fccabb34 Revert "Update vendored llama.cpp to b7847" (#14061) 2026-02-03 18:39:36 -08:00
Bruce MacDonald
a6355329bf cmd: open browser on ollama signin when available (#14055)
When a browser is available open it to the connect URL automatically when running the `ollama signin` command. Browser is not opened in any other unauthorized scenario.
2026-02-03 16:42:09 -08:00
Parth Sareen
0398b24b42 cmd: launch defaults (#14035) 2026-02-02 23:19:11 -08:00
Parth Sareen
75b1dddf91 cmd: launch extra params (#14039) 2026-02-03 02:03:33 -05:00
Parth Sareen
e1e80ffc3e cmd/config: move config location (#14034) 2026-02-02 22:48:51 -05:00
Aleksandr Vukmirovich
71896485fd anthropic: add InputTokens to streaming response (#13934)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-02-02 18:29:37 -08:00
Jeffrey Morgan
ef00199fb4 Update vendor ggml code to a5bb8ba4 (#13832)
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Shalini Salomi Bodapati <Shalini.Salomi.Bodapati@ibm.com>
2026-02-02 17:31:59 -08:00
Jeffrey Morgan
8f4a008139 Add GLM-OCR vision model support (#14024) 2026-02-02 15:39:18 -08:00
Patrick Devine
d8cc798c2b glm 4.7 flash support on experimental engine (#13838) 2026-02-02 15:22:11 -08:00
Richard Lyons
6582f6da5c llm: Make "do load request" error message more informative 2026-02-02 11:13:21 -08:00
Jesse Gross
0334ffa625 server: use tiered VRAM-based default context length
Replace binary low VRAM mode with tiered VRAM thresholds that set
default context lengths for all models:

- < 24 GiB VRAM: 4,096 context
- 24-48 GiB VRAM: 32,768 context
- >= 48 GiB VRAM: 262,144 context
2026-02-02 10:47:09 -08:00
Jesse Gross
d11fbd2c60 server: fix ollama ps showing configured instead of actual context length
When context length is clamped to the model's trained context length,
ollama ps now shows the actual clamped value instead of the originally
configured value.
2026-02-02 10:47:09 -08:00
Jeffrey Morgan
6a7c3f188e openclaw: run onboarding for fresh installs (#14006)
When launching OpenClaw without prior onboarding, run the onboarding
wizard instead of going straight to gateway. This ensures proper
gateway configuration (mode, token, etc.) before first use.

- Add onboarded() to check for wizard.lastRunAt marker in config
- Run onboard with --auth-choice skip --gateway-token ollama for fresh installs
- Existing installs (onboarding completed) run gateway directly
2026-02-01 13:46:45 -08:00
Jeffrey Morgan
427e2c962a docs: add redirect from clawdbot to openclaw (#14004) 2026-01-31 20:50:42 -08:00
Thanh Nguyen
27db7f806f cmd/config: rename integration to openclaw (#13979)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-01-31 18:31:13 -05:00
Dhiraj Lochib
3590fbfa76 runner: fix typo 'baackend' -> 'backend' in error messages (#13645)
Fix typo in three error messages where 'baackend' was written instead
of 'backend' in the /health endpoint handler when initializing the
dummy model load.
2026-01-31 13:26:20 -08:00
noureldin-azzab
cd0094f772 added stakpak to web & desktop (#13961) 2026-01-31 13:04:34 -08:00
Louis Beaumont
06bc8e6712 docs: add Screenpipe to Community Integrations (#13906)
Screenpipe is a 24/7 screen & mic recording tool that uses Ollama
for local LLM-powered search and AI features. 16k+ GitHub stars.
2026-01-31 12:49:52 -08:00
frob
fc5f9bb448 docs: remove unsupported quantizations (#13982) 2026-01-31 12:46:20 -08:00
frob
a0740f7ef7 docs: add GB10 to supported devices (#13987) 2026-01-31 12:45:27 -08:00
Parth Sareen
a0923cbdd0 cmd: ollama launch add placeholder text for selector (#13966) 2026-01-29 09:48:49 -08:00
Seokrin Taron Sung
f92e362b2e cmd: capitalize Ollama in serve command help text (#13965) 2026-01-29 09:47:53 -08:00
Tincho
aa23d8ecd2 docs: update installation command for OpenCode CLI (#13971) 2026-01-29 09:47:02 -08:00
246 changed files with 29360 additions and 9585 deletions

22
.github/workflows/test-install.yaml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: test-install
on:
pull_request:
paths:
- 'scripts/install.sh'
- '.github/workflows/test-install.yaml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Run install script
run: sh ./scripts/install.sh
env:
OLLAMA_NO_START: 1 # do not start app
- name: Verify ollama is available
run: ollama --version

View File

@@ -358,6 +358,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
@@ -465,6 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)
### Cloud

159
anthropic/anthropic.go Normal file → Executable file
View File

@@ -211,6 +211,7 @@ type MessageDelta struct {
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
@@ -517,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
ID: id,
Model: model,
firstWrite: true,
estimatedInputTokens: estimatedInputTokens,
toolCallsSent: make(map[string]bool),
}
}
@@ -550,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
if c.firstWrite {
c.firstWrite = false
// Use actual metrics if available, otherwise use estimate
c.inputTokens = r.Metrics.PromptEvalCount
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
c.inputTokens = c.estimatedInputTokens
}
events = append(events, StreamEvent{
Event: "message_start",
@@ -721,6 +728,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
})
}
c.inputTokens = r.Metrics.PromptEvalCount
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
@@ -732,6 +740,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
StopReason: stopReason,
},
Usage: DeltaUsage{
InputTokens: c.inputTokens,
OutputTokens: c.outputTokens,
},
},
@@ -776,3 +785,123 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
Model string `json:"model"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
}
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
return estimateTokens(CountTokensRequest{
Model: req.Model,
Messages: req.Messages,
System: req.System,
Tools: req.Tools,
Thinking: req.Thinking,
})
}
// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
InputTokens int `json:"input_tokens"`
}
// estimateTokens returns a rough estimate of tokens (len/4).
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
func estimateTokens(req CountTokensRequest) int {
var totalLen int
// Count system prompt
if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages {
// Count role (always present)
totalLen += len(msg.Role)
// Count content
contentLen := countAnyContent(msg.Content)
totalLen += contentLen
}
for _, tool := range req.Tools {
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
}
// Return len/4 as rough token estimate, minimum 1 if there's any content
tokens := totalLen / 4
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
tokens = 1
}
return tokens
}
func countAnyContent(content any) int {
if content == nil {
return 0
}
switch c := content.(type) {
case string:
return len(c)
case []any:
total := 0
for _, block := range c {
total += countContentBlock(block)
}
return total
default:
if data, err := json.Marshal(content); err == nil {
return len(data)
}
return 0
}
}
func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0
blockType, _ := blockMap["type"].(string)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
}
if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
}
if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

142
anthropic/anthropic_test.go Normal file → Executable file
View File

@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk
resp1 := api.ChatResponse{
@@ -642,7 +640,7 @@ func TestStreamConverter_Basic(t *testing.T) {
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events2 := conv.Process(resp2)
@@ -650,6 +648,24 @@ func TestStreamConverter_Basic(t *testing.T) {
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_delta" {
if data, ok := e.Data.(MessageDeltaEvent); ok {
if data.Type != "message_delta" {
t.Errorf("unexpected data type: %+v", data)
}
if data.Delta.StopReason != "end_turn" {
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
}
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", data.Usage)
}
} else {
t.Errorf("unexpected data: %+v", e.Data)
}
}
if e.Event == "message_stop" {
hasStop = true
}
@@ -660,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -713,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
@@ -760,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
@@ -824,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
@@ -881,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -919,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
@@ -951,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
}
})
}
func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"},
},
}
tokens := estimateTokens(req)
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
if tokens < 1 {
t.Errorf("expected at least 1 token, got %d", tokens)
}
// Sanity check: shouldn't be wildly off
if tokens > 10 {
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
}
}
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
tokens := estimateTokens(req)
// System prompt adds to count
if tokens < 5 {
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
}
}
func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"},
},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get the current weather for a location",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
},
},
}
tokens := estimateTokens(req)
// Tools add significant content
if tokens < 10 {
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
}
}
func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this carefully...",
},
map[string]any{
"type": "text",
"text": "Here is my response.",
},
},
},
},
}
tokens := estimateTokens(req)
// Thinking content should be counted
if tokens < 10 {
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
}
}
func TestEstimateTokens_EmptyContent(t *testing.T) {
req := CountTokensRequest{
Model: "test-model",
Messages: []MessageParam{},
}
tokens := estimateTokens(req)
if tokens != 0 {
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
}
}

View File

@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -29,6 +29,7 @@ import (
"github.com/containerd/console"
"github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter"
"github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
@@ -52,7 +53,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -366,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
} else if info.RemoteHost != "" {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
if _, err := client.Whoami(cmd.Context()); err != nil {
return err
}
}
if opts.ShowConnect {
p.StopAndClear()
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
@@ -663,6 +675,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
fmt.Println()
if aErr.SigninURL != "" {
_ = browser.OpenURL(aErr.SigninURL)
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
@@ -1750,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err)
return err
}
}
return nil
@@ -1888,7 +1901,7 @@ func NewCLI() *cobra.Command {
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
Short: "Start ollama",
Short: "Start Ollama",
Args: cobra.ExactArgs(0),
RunE: RunServer,
}

View File

@@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -1553,7 +1554,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1566,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Think should not be affected by original modification")
}
}
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
remoteHost string
whoamiStatus int
whoamiResp any
expectedError string
}{
{
name: "ollama.com cloud model - user signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "ollama.com cloud model - user not signed in",
remoteHost: "https://ollama.com",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
},
expectedError: "unauthorized",
},
{
name: "non-ollama.com remote - no auth check",
remoteHost: "https://other-remote.com",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
whoamiCalled := false
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
case "/api/me":
whoamiCalled = true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.whoamiStatus)
if tt.whoamiResp != nil {
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
default:
http.NotFound(w, r)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
}
if tt.expectedError != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
}
} else {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}
})
}
}

View File

@@ -1,25 +1,32 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
return []string{"--model", model}
args = append(args, "--model", model)
}
return nil
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
@@ -41,13 +48,13 @@ func (c *Claude) findPath() (string, error) {
return fallback, nil
}
func (c *Claude) Run(model string) error {
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model)...)
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -58,3 +65,104 @@ func (c *Claude) Run(model string) error {
)
return cmd.Run()
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := selectPrompt("Select model:", items)
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -84,17 +84,21 @@ func TestClaudeArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}

View File

@@ -14,20 +14,21 @@ type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
func (c *Codex) args(model string, extra []string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string) error {
func (c *Codex) Run(model string, args []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd := exec.Command("codex", c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}

View File

@@ -6,13 +6,15 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
)
type integration struct {
Models []string `json:"models"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
}
type config struct {
@@ -20,6 +22,14 @@ type config struct {
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config.json"), nil
}
func legacyConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
@@ -27,6 +37,46 @@ func configPath() (string, error) {
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
func migrateConfig() (bool, error) {
oldPath, err := legacyConfigPath()
if err != nil {
return false, err
}
oldData, err := os.ReadFile(oldPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
var js json.RawMessage
if err := json.Unmarshal(oldData, &js); err != nil {
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
return false, nil
}
newPath, err := configPath()
if err != nil {
return false, err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return false, err
}
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
return false, fmt.Errorf("write new config: %w", err)
}
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
slog.Info("migrated config", "from", oldPath, "to", newPath)
return true, nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
@@ -34,6 +84,11 @@ func load() (*config, error) {
}
data, err := os.ReadFile(path)
if err != nil && os.IsNotExist(err) {
if migrated, merr := migrateConfig(); merr == nil && migrated {
data, err = os.ReadFile(path)
}
}
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
@@ -79,8 +134,16 @@ func saveIntegration(appName string, models []string) error {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
}
return save(cfg)
@@ -100,6 +163,29 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -0,0 +1,677 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}

View File

@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})
@@ -200,12 +247,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Create corrupted config.json file
dir := filepath.Join(tmpDir, ".ollama", "config")
dir := filepath.Join(tmpDir, ".ollama")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
// Corrupted file is treated as empty, so loadIntegration returns not found
_, err := loadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
@@ -267,7 +312,7 @@ func TestConfigPath(t *testing.T) {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
expected := filepath.Join(tmpDir, ".ollama", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
@@ -322,6 +367,183 @@ func TestLoad(t *testing.T) {
})
}
func TestMigrateConfig(t *testing.T) {
t.Run("migrates legacy file to new location", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
newPath, _ := configPath()
got, err := os.ReadFile(newPath)
if err != nil {
t.Fatalf("new config not found: %v", err)
}
if string(got) != string(data) {
t.Errorf("content mismatch: got %s", got)
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("legacy file should have been removed")
}
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
t.Error("legacy directory should have been removed")
}
})
t.Run("no-op when no legacy file exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("expected no migration")
}
})
t.Run("skips corrupt legacy file", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("should not migrate corrupt file")
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
t.Error("corrupt legacy file should not have been deleted")
}
})
t.Run("new path takes precedence over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
newDir := filepath.Join(tmpDir, ".ollama")
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if _, ok := cfg.Integrations["new"]; !ok {
t.Error("expected new-path integration to be loaded")
}
if _, ok := cfg.Integrations["old"]; ok {
t.Error("legacy integration should not have been loaded")
}
})
t.Run("idempotent when called twice", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("second migration should be a no-op")
}
})
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
t.Error("directory with other files should not have been removed")
}
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
t.Error("other files in legacy directory should be untouched")
}
})
t.Run("save writes to new path after migration", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
// load triggers migration, then save should write to new path
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
t.Fatal(err)
}
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
if _, err := os.Stat(newPath); os.IsNotExist(err) {
t.Error("save should write to new path")
}
// old path should not be recreated
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("save should not recreate legacy path")
}
})
t.Run("load triggers migration transparently", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
t.Error("migration via load() did not preserve data")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -39,7 +39,7 @@ type modelEntry struct {
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
func (d *Droid) Run(model string, args []string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
@@ -53,7 +53,7 @@ func (d *Droid) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -22,7 +23,7 @@ import (
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
Run(model string, args []string) error
// String returns the human-readable name of the integration
String() string
}
@@ -38,13 +39,39 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Clawdbot{},
"clawdbot": &Openclaw{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
}
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
}
func selectIntegration() (string, error) {
@@ -55,6 +82,9 @@ func selectIntegration() (string, error) {
names := slices.Sorted(maps.Keys(integrations))
var items []selectItem
for _, name := range names {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
@@ -83,152 +113,212 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
if len(models.Models) == 0 {
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
var items []selectItem
cloudModels := make(map[string]bool)
var existing []modelInfo
for _, m := range models.Models {
if m.RemoteModel != "" {
cloudModels[m.Name] = true
}
items = append(items, selectItem{Name: m.Name})
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
if len(items) == 0 {
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
}
// Get previously configured models (saved config takes precedence)
var preChecked []string
if saved, err := loadIntegration(name); err == nil {
preChecked = saved.Models
} else if editor, ok := r.(Editor); ok {
preChecked = editor.Models()
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
// If current model is configured, move to front of preChecked
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
if len(items) == 0 {
return nil, fmt.Errorf("no models available")
}
// Sort: checked first, then alphabetical
slices.SortFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
var selected []string
// only editors support multi-model selection
if _, ok := r.(Editor); ok {
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
if err != nil {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
if err != nil {
return nil, err
}
selected = []string{model}
}
// if any model in selected is a cloud model, ensure signed in
var toPull []string
for _, m := range selected {
if !existingModels[m] {
toPull = append(toPull, m)
}
}
if len(toPull) > 0 {
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
if ok, err := confirmPrompt(msg); err != nil {
return nil, err
} else if !ok {
return nil, errCancelled
}
for _, m := range toPull {
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, m); err != nil {
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
}
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{
Name: m.Name,
Remote: m.RemoteModel != "",
})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
if len(selectedCloudModels) == 0 {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return fmt.Errorf("%s requires sign in", modelList)
}
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
return selected, nil
}
func runIntegration(name, modelName string) error {
func runIntegration(name, modelName string, args []string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
return r.Run(modelName, args)
}
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
} else {
delete(aliases, "fast")
}
if err := ac.SetAliases(ctx, aliases); err != nil {
return err
}
return saveAliases(name, aliases)
}
// LaunchCmd returns the cobra command for launching integrations.
@@ -237,29 +327,52 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported integrations:
claude Claude Code
clawdbot Clawdbot
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// Extract integration name and args to pass through using -- separator
var name string
if len(args) > 0 {
name = args[0]
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
// No "--" separator: only allow 0 or 1 args (integration name)
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
// "--" was used: args before it = integration name, args after = passthrough
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
@@ -275,16 +388,92 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var models []string
if modelFlag != "" {
// When --model is specified, merge with existing models (new model becomes default)
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
@@ -293,6 +482,8 @@ Examples:
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
@@ -339,13 +530,13 @@ Examples:
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
return runIntegration(name, models[0], passArgs)
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0])
return runIntegration(name, models[0], passArgs)
},
}
@@ -353,3 +544,163 @@ Examples:
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
// the ordered items along with maps of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
for _, rec := range recommendedModels {
recommended[rec.Name] = true
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := selectItem{Name: displayName}
if recommended[displayName] {
item.Description = "recommended"
}
items = append(items, item)
}
for _, rec := range recommendedModels {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if strings.HasSuffix(rec.Name, ":cloud") {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
for _, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
// Non-existing models get "install?" suffix and are pushed to the bottom.
// When user has no models, preserve recommended order.
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] {
notInstalled[items[i].Name] = true
if items[i].Description != "" {
items[i].Description += ", install?"
} else {
items[i].Description = "install?"
}
}
}
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b selectItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if !ac && !bc && aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
// isCloudModel checks if a model is a cloud model using the Show API.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
func pullModel(ctx context.Context, client *api.Client, model string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model}
return client.Pull(ctx, &request, fn)
}

View File

@@ -1,10 +1,13 @@
package config
import (
"context"
"fmt"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
)
@@ -90,8 +93,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -121,7 +124,7 @@ func TestLaunchCmd(t *testing.T) {
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
@@ -174,15 +177,365 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
for name, r := range integrations {
t.Run(name, func(t *testing.T) {
// Test String() doesn't panic and returns non-empty
displayName := r.String()
if displayName == "" {
t.Error("String() should not return empty")
}
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
var _ func(string, []string) error = r.Run
})
}
}
func TestParseArgs(t *testing.T) {
// Tests reflect cobra's ArgsLenAtDash() semantics:
// - cobra strips "--" from args
// - ArgsLenAtDash() returns the index where "--" was, or -1
tests := []struct {
name string
args []string // args as cobra delivers them (no "--")
dashIdx int // what ArgsLenAtDash() returns
wantName string
wantArgs []string
wantErr bool
}{
{
name: "no extra args, no dash",
args: []string{"claude"},
dashIdx: -1,
wantName: "claude",
},
{
name: "with extra args after --",
args: []string{"codex", "-p", "myprofile"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"-p", "myprofile"},
},
{
name: "extra args only after --",
args: []string{"codex", "--sandbox", "workspace-write"},
dashIdx: 1,
wantName: "codex",
wantArgs: []string{"--sandbox", "workspace-write"},
},
{
name: "-- at end with no args after",
args: []string{"claude"},
dashIdx: 1,
wantName: "claude",
},
{
name: "-- with no integration name",
args: []string{"--verbose"},
dashIdx: 0,
wantName: "",
wantArgs: []string{"--verbose"},
},
{
name: "multiple args before -- is error",
args: []string{"claude", "codex", "--verbose"},
dashIdx: 2,
wantErr: true,
},
{
name: "multiple args without -- is error",
args: []string{"claude", "codex"},
dashIdx: -1,
wantErr: true,
},
{
name: "no args, no dash",
args: []string{},
dashIdx: -1,
wantName: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the parsing logic from LaunchCmd using dashIdx
var name string
var parsedArgs []string
var err error
dashIdx := tt.dashIdx
args := tt.args
if dashIdx == -1 {
if len(args) > 1 {
err = fmt.Errorf("unexpected arguments: %v", args[1:])
} else if len(args) == 1 {
name = args[0]
}
} else {
if dashIdx > 1 {
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
} else {
if dashIdx == 1 {
name = args[0]
}
parsedArgs = args[dashIdx:]
}
}
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if name != tt.wantName {
t.Errorf("name = %q, want %q", name, tt.wantName)
}
if !slices.Equal(parsedArgs, tt.wantArgs) {
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
}
})
}
}
func TestIsCloudModel(t *testing.T) {
// isCloudModel now only uses Show API, so nil client always returns false
t.Run("nil client returns false", func(t *testing.T) {
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
for _, model := range models {
if isCloudModel(context.Background(), nil, model) {
t.Errorf("isCloudModel(%q) with nil client should return false", model)
}
}
})
}
func names(items []selectItem) []string {
var out []string
for _, item := range items {
out = append(out, item.Name)
}
return out
}
func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
}
for _, item := range items {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
}
}
}
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "qwen2.5:latest", Remote: false},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
}
}
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
}
}
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
got := names(items)
if got[0] != "llama3.2" {
t.Errorf("pre-checked model should be first, got %v", got)
}
}
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
for _, item := range items {
switch item.Name {
case "glm-4.7-flash", "glm-4.7:cloud":
if strings.HasSuffix(item.Description, "install?") {
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
}
case "kimi-k2.5:cloud", "qwen3:8b":
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
}
}
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "kimi-k2.5:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, nil, "")
got := names(items)
// kimi-k2.5:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
}
for _, item := range items {
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
if !strings.HasSuffix(item.Description, "install?") {
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
}
}
}
}
func TestBuildModelList_LatestTagStripped(t *testing.T) {
existing := []modelInfo{
{Name: "glm-4.7-flash:latest", Remote: false},
{Name: "llama3.2:latest", Remote: false},
}
items, _, existingModels, _ := buildModelList(existing, nil, "")
got := names(items)
// :latest should be stripped from display names
for _, name := range got {
if strings.HasSuffix(name, ":latest") {
t.Errorf("name %q should not have :latest suffix", name)
}
}
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
count := 0
for _, name := range got {
if name == "glm-4.7-flash" {
count++
}
}
if count != 1 {
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
}
// Stripped name should be in existingModels so it won't be pulled
if !existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should be in existingModels")
}
}
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false},
{Name: "glm-4.7:cloud", Remote: true},
}
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if !existingModels["llama3.2"] {
t.Error("llama3.2 should be in existingModels")
}
if !existingModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in existingModels")
}
if existingModels["glm-4.7-flash"] {
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
}
if !cloudModels["glm-4.7:cloud"] {
t.Error("glm-4.7:cloud should be in cloudModels")
}
if !cloudModels["kimi-k2.5:cloud"] {
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
}
if cloudModels["llama3.2"] {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save a config for opencode so it looks like a previous launch
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify loadIntegration returns the saved models
saved, err := loadIntegration("opencode")
if err != nil {
t.Fatal(err)
}
if len(saved.Models) == 0 {
t.Fatal("expected saved models")
}
if saved.Models[0] != "llama3.2" {
t.Errorf("expected llama3.2, got %s", saved.Models[0])
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}

View File

@@ -13,26 +13,44 @@ import (
"github.com/ollama/ollama/envconfig"
)
type Clawdbot struct{}
type Openclaw struct{}
func (c *Clawdbot) String() string { return "Clawdbot" }
func (c *Openclaw) String() string { return "OpenClaw" }
const ansiGreen = "\033[32m"
func (c *Clawdbot) Run(model string) error {
if _, err := exec.LookPath("clawdbot"); err != nil {
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("clawdbot", "gateway")
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
cmd := exec.Command(bin, "onboard",
"--auth-choice", "skip",
"--gateway-token", "ollama",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
@@ -42,22 +60,55 @@ func (c *Clawdbot) Run(model string) error {
err := cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
func (c *Clawdbot) Paths() []string {
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Clawdbot) Edit(models []string) error {
func (c *Openclaw) Edit(models []string) error {
if len(models) == 0 {
return nil
}
@@ -67,7 +118,8 @@ func (c *Clawdbot) Edit(models []string) error {
return err
}
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
@@ -76,6 +128,8 @@ func (c *Clawdbot) Edit(models []string) error {
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
@@ -167,15 +221,18 @@ func (c *Clawdbot) Edit(models []string) error {
return writeWithBackup(configPath, data)
}
func (c *Clawdbot) Models() []string {
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
return nil
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)

View File

@@ -8,12 +8,12 @@ import (
"testing"
)
func TestClawdbotIntegration(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawIntegration(t *testing.T) {
c := &Openclaw{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Clawdbot" {
t.Errorf("String() = %q, want %q", got, "Clawdbot")
if got := c.String(); got != "OpenClaw" {
t.Errorf("String() = %q, want %q", got, "OpenClaw")
}
})
@@ -26,13 +26,13 @@ func TestClawdbotIntegration(t *testing.T) {
})
}
func TestClawdbotEdit(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
@@ -41,8 +41,8 @@ func TestClawdbotEdit(t *testing.T) {
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
@@ -50,9 +50,9 @@ func TestClawdbotEdit(t *testing.T) {
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelExists(t, configPath, "mistral")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelExists(t, configPath, "mistral")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
@@ -127,8 +127,8 @@ func TestClawdbotEdit(t *testing.T) {
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertClawdbotModelExists(t, configPath, "llama3.2")
assertClawdbotModelNotExists(t, configPath, "mistral")
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
@@ -169,12 +169,12 @@ func TestClawdbotEdit(t *testing.T) {
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2")
assertOpenclawModelExists(t, configPath, "llama3.2")
})
}
func TestClawdbotModels(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawModels(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -185,9 +185,9 @@ func TestClawdbotModels(t *testing.T) {
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".clawdbot")
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
@@ -202,7 +202,7 @@ func TestClawdbotModels(t *testing.T) {
}
// Helper functions
func assertClawdbotModelExists(t *testing.T, path, model string) {
func assertOpenclawModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
@@ -221,7 +221,7 @@ func assertClawdbotModelExists(t *testing.T, path, model string) {
t.Errorf("model %s not found", model)
}
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
@@ -239,7 +239,7 @@ func assertClawdbotModelNotExists(t *testing.T, path, model string) {
}
}
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
@@ -252,15 +252,15 @@ func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
}
}
func TestClawdbotPaths(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawPaths(t *testing.T) {
c := &Openclaw{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
@@ -277,12 +277,12 @@ func TestClawdbotPaths(t *testing.T) {
})
}
func TestClawdbotModelsEdgeCases(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawModelsEdgeCases(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
@@ -340,11 +340,11 @@ func TestClawdbotModelsEdgeCases(t *testing.T) {
})
}
func TestClawdbotEditSchemaFields(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEditSchemaFields(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
@@ -381,20 +381,20 @@ func TestClawdbotEditSchemaFields(t *testing.T) {
}
}
func TestClawdbotEditModelNames(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEditModelNames(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
@@ -402,8 +402,8 @@ func TestClawdbotEditModelNames(t *testing.T) {
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "library/model:tag")
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
assertOpenclawModelExists(t, configPath, "library/model:tag")
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
@@ -411,16 +411,16 @@ func TestClawdbotEditModelNames(t *testing.T) {
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertClawdbotModelExists(t, configPath, "test-model")
assertOpenclawModelExists(t, configPath, "test-model")
})
}
func TestClawdbotEditAgentsPreservation(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEditAgentsPreservation(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
@@ -457,7 +457,7 @@ func TestClawdbotEditAgentsPreservation(t *testing.T) {
})
}
const testClawdbotFixture = `{
const testOpenclawFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
@@ -475,15 +475,15 @@ const testClawdbotFixture = `{
}
}`
func TestClawdbotEdit_RoundTrip(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEdit_RoundTrip(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
@@ -521,15 +521,15 @@ func TestClawdbotEdit_RoundTrip(t *testing.T) {
}
}
func TestClawdbotEdit_Idempotent(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEdit_Idempotent(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
@@ -542,15 +542,15 @@ func TestClawdbotEdit_Idempotent(t *testing.T) {
}
}
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
@@ -573,12 +573,12 @@ func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
}
}
func TestClawdbotEdit_BackupCreated(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawEdit_BackupCreated(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configPath := filepath.Join(configDir, "clawdbot.json")
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
@@ -590,7 +590,7 @@ func TestClawdbotEdit_BackupCreated(t *testing.T) {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
@@ -605,11 +605,151 @@ func TestClawdbotEdit_BackupCreated(t *testing.T) {
}
}
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Clawdbot{}
func TestOpenclawClawdbotAlias(t *testing.T) {
for _, alias := range []string{"clawdbot", "moltbot"} {
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
r, ok := integrations[alias]
if !ok {
t.Fatalf("%s not found in integrations", alias)
}
if _, ok := r.(*Openclaw); !ok {
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
}
})
t.Run(alias+" is hidden from selector", func(t *testing.T) {
if !integrationAliases[alias] {
t.Errorf("%s should be in integrationAliases", alias)
}
})
}
}
func TestOpenclawLegacyPaths(t *testing.T) {
c := &Openclaw{}
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
t.Errorf("expected legacy path, got %s", paths[0])
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(newDir, "openclaw.json") {
t.Errorf("expected new path, got %s", paths[0])
}
})
t.Run("Models reads from legacy path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "llama3.2" {
t.Errorf("expected [llama3.2], got %v", models)
}
})
t.Run("Models prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "new-model" {
t.Errorf("expected [new-model], got %v", models)
}
})
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "new" {
t.Errorf("expected theme from new config, got %v", cfg["theme"])
}
})
t.Run("Edit migrates from legacy config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Should write to new path
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
data, err := os.ReadFile(newPath)
if err != nil {
t.Fatal("expected new config file to be created")
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("legacy theme setting was not migrated")
}
})
}
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".clawdbot")
configDir := filepath.Join(tmpDir, ".openclaw")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
@@ -623,3 +763,116 @@ func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
t.Fatal("directory was not created")
}
}
func TestOpenclawOnboarded(t *testing.T) {
c := &Openclaw{}
t.Run("returns false when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if c.onboarded() {
t.Error("expected false when no config exists")
}
})
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
if c.onboarded() {
t.Error("expected false when no wizard section")
}
})
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is missing")
}
})
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is empty")
}
})
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when wizard.lastRunAt is set")
}
})
t.Run("checks legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when legacy config has wizard.lastRunAt")
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
// New path has no wizard marker
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
// Legacy has wizard marker
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if c.onboarded() {
t.Error("expected false - should prefer new path which has no wizard marker")
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
if c.onboarded() {
t.Error("expected false for corrupted JSON")
}
})
t.Run("handles wrong type for wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard is wrong type")
}
})
}

View File

@@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/json"
"fmt"
"maps"
@@ -10,15 +11,55 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
func (o *OpenCode) Run(model string, args []string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
@@ -32,7 +73,7 @@ func (o *OpenCode) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd := exec.Command("opencode", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -113,6 +154,8 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -122,12 +165,29 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
models[model] = map[string]any{
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models

View File

@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -495,6 +496,165 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
if entry["limit"] != nil {
t.Errorf("local model should not have limit set, got %v", entry["limit"])
}
}
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
// Set up a model with a user-configured limit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"llama3.2": {
"name": "llama3.2",
"_launch": true,
"limit": {"context": 8192, "output": 4096}
}
}
}
}
}`), 0o644)
// Re-edit should preserve the user's limit (not delete it)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "llama3.2")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("user-configured limit was removed")
}
if limit["context"] != float64(8192) {
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
}
if limit["output"] != float64(4096) {
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
}
}
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
// Verify that when a cloud model entry has limits set (as Edit would do),
// the structure matches what opencode expects and re-edit preserves them.
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
expected := cloudModelLimits["glm-4.7"]
// Simulate a cloud model that already has the limit set by a previous Edit
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
"provider": {
"ollama": {
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud",
"_launch": true,
"limit": {"context": %d, "output": %d}
}
}
}
}
}`, expected.Context, expected.Output)), 0o644)
// Re-edit should preserve the cloud model limit
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was removed on re-edit")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestOpenCodeModels_NoConfig(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()

View File

@@ -17,6 +17,7 @@ const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiClearDown = "\033[J"
)
@@ -275,7 +276,11 @@ func parseInput(r io.Reader) (inputEvent, byte, error) {
func renderSelect(w io.Writer, prompt string, s *selectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
@@ -314,7 +319,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
filtered := s.filtered()
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
if s.filter == "" {
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
} else {
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
}
lineCount := 1
if len(filtered) == 0 {
@@ -345,10 +354,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
suffix = " " + ansiGray + "(default)" + ansiReset
}
desc := ""
if item.Description != "" {
desc = " " + ansiGray + "- " + item.Description + ansiReset
}
if idx == s.highlighted && !s.focusOnButton {
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
} else {
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
}
lineCount++
}

View File

@@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) {
}
})
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState([]selectItem{})
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
s := newSelectState(items)
done, result, err := s.handleInput(eventEscape, 0)
@@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) {
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "no matches") {
t.Errorf("expected 'no matches' message, got: %s", output)
}
})
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
s := newSelectState([]selectItem{})
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
t.Error("expected 'no matches' message for empty list with no filter")
}
})

View File

@@ -10,19 +10,21 @@ import (
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
return errNotRunning
}
link, err := os.Readlink(exe)
if err != nil {
return err
return errNotRunning
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app")
return errNotRunning
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err

View File

@@ -313,8 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
case "Qwen3NextForCausalLM":
conv = &qwen3NextModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

455
convert/convert_glmocr.go Normal file
View File

@@ -0,0 +1,455 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
//
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
rotaryDim := int(float32(headDim) * partialRotaryFactor)
if rotaryDim%2 != 0 {
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
}
// Handle 1D (bias) or 2D (weight) tensors
is1D := len(shape) == 1
var inFeatures int
if is1D {
inFeatures = 1
} else {
inFeatures = int(shape[1])
}
outFeatures := int(shape[0])
nEffectiveHeads := outFeatures / headDim
if nEffectiveHeads != nHeads {
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
}
// Reshape to [n_heads, head_dim, in_features]
reshaped := make([]float32, len(data))
copy(reshaped, data)
// Permute the rotary dimensions: even indices first, then odd
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
result := make([]float32, len(data))
halfRotary := rotaryDim / 2
for h := range nEffectiveHeads {
for f := range inFeatures {
for i := range halfRotary {
// Even dim (0, 2, 4, ...) -> position i
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
dstIdx := h*headDim*inFeatures + i*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
result[dstIdx] = reshaped[srcIdx]
}
// Non-rotary part: copy as-is
for i := rotaryDim; i < headDim; i++ {
srcIdx := h*headDim*inFeatures + i*inFeatures + f
result[srcIdx] = reshaped[srcIdx]
}
}
}
return result, nil
}
}
type glmOcrModel struct {
ModelParameters
TextConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
RMSNormEps float32 `json:"rms_norm_eps"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeParameters struct {
RopeType string `json:"rope_type"`
MRopeSection []int32 `json:"mrope_section"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
Depth uint32 `json:"depth"`
NumHeads uint32 `json:"num_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
OutHiddenSize uint32 `json:"out_hidden_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
ImageStartTokenID uint32 `json:"image_start_token_id"`
ImageEndTokenID uint32 `json:"image_end_token_id"`
VideoStartTokenID uint32 `json:"video_start_token_id"`
VideoEndTokenID uint32 `json:"video_end_token_id"`
ImageTokenID uint32 `json:"image_token_id"`
VideoTokenID uint32 `json:"video_token_id"`
// Preprocessor config (preprocessor_config.json)
Preprocessor struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"-"`
}
var _ ModelConverter = (*glmOcrModel)(nil)
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.Preprocessor)
}
func (m *glmOcrModel) KV(t *Tokenizer) KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "glmocr"
// Text model parameters
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
kv["glmocr.attention.key_length"] = headDim
kv["glmocr.attention.value_length"] = headDim
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
}
// Vision model parameters
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
// Preprocessor-derived image settings (min/max pixels and normalization)
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
if m.Preprocessor.Size.ShortestEdge > 0 {
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
}
if m.Preprocessor.Size.LongestEdge > 0 {
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
}
if len(m.Preprocessor.ImageMean) == 3 {
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
}
if len(m.Preprocessor.ImageStd) == 3 {
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
}
// Special tokens
kv["glmocr.image_token_id"] = m.ImageTokenID
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
kv["glmocr.video_token_id"] = m.VideoTokenID
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
return kv
}
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
skipLayer := func(name string) bool {
// Tensor names are already replaced to "blk.N.xxx" format
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(name)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return blkNum >= numLayers
}
for _, t := range ts {
name := t.Name()
// Skip next-n prediction layers (layers >= num_hidden_layers)
if skipLayer(name) {
continue
}
// Split ffn_gate_up into separate gate and up projections
if strings.Contains(name, "ffn_gate_up") {
for t := range splitDim(t, 0,
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
) {
out = append(out, t)
}
continue
}
if strings.HasSuffix(name, "patch_embd.weight") {
shape := t.Shape()
if len(shape) == 5 && shape[2] == 2 {
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
t0 := t.Clone()
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t0,
})
t1 := t.Clone()
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
if err := tt.Reshape(newDims...); err != nil {
return nil, err
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t1,
})
continue
}
if len(shape) == 4 {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
// Fall through to default handling
}
// Handle pre-split patch embedding weights
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
if strings.Contains(name, "patch_embd.0.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.Contains(name, "patch_embd.1.") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Handle .weight.0 and .weight.1 suffix patterns
if strings.HasSuffix(name, "patch_embd.weight.0") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
if strings.HasSuffix(name, "patch_embd.weight.1") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
continue
}
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
// We permute at conversion time so the weights work correctly with GGML's kernel
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
// Get config values for permutation
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
// Use appropriate head count: nHeads for Q, nKVHeads for K
effectiveHeads := nHeads
if strings.Contains(name, "attn_k.") {
effectiveHeads = nKVHeads
}
permutedT := t.Clone()
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: permutedT,
})
continue
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *glmOcrModel) Replacements() []string {
return []string{
// Vision encoder
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
"model.visual.patch_embed.proj", "v.patch_embd",
"model.visual.blocks", "v.blk",
"model.visual.post_layernorm", "v.post_ln",
"model.visual.downsample", "mm.patch_merger",
// Vision attention
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"attn.q_norm", "attn_q_norm",
"attn.k_norm", "attn_k_norm",
// Vision norms
"norm1", "ln1",
"norm2", "ln2",
// Vision MLP
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
// Merger (multimodal projector)
"model.visual.merger.proj", "mm.model.fc",
"model.visual.merger.post_projection_norm", "mm.post_norm",
"model.visual.merger.gate_proj", "mm.gate",
"model.visual.merger.up_proj", "mm.up",
"model.visual.merger.down_proj", "mm.down",
// Language model
"model.language_model.embed_tokens", "token_embd",
"model.language_model.layers", "blk",
"model.language_model.norm", "output_norm",
"lm_head", "output",
// Language model attention
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
// Language model norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"post_self_attn_layernorm", "post_attn_norm",
"post_mlp_layernorm", "post_ffn_norm",
// Language model MLP (remove mlp. prefix so ffn_* names work)
"mlp.gate_up_proj", "ffn_gate_up",
"mlp.down_proj", "ffn_down",
}
}

View File

@@ -0,0 +1,512 @@
package convert
import (
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
if q.NumAttentionHeads == 0 {
return fmt.Errorf("qwen3next: num_attention_heads must be set")
}
if q.NumKeyValueHeads == 0 {
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
}
if q.HeadDim == 0 {
return fmt.Errorf("qwen3next: head_dim must be set")
}
if q.RopeTheta == 0 {
return fmt.Errorf("qwen3next: rope_theta must be set")
}
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
}
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
return nil
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
}
kv["attention.key_length"] = headDim
kv["attention.value_length"] = headDim
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
if q.SharedExpertIntermSize > 0 {
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
return kv
}
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
continue
}
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
}
return out
}
type qkvzSplitSpec struct {
hidden int
headKDim int
headVDim int
numKHeads int
numVHeads int
qkvzDim int
qkvOut int
gateOut int
}
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
if len(shape) != 2 {
return qkvzSplitSpec{}, false
}
numKHeads := int(q.LinearNumKeyHeads)
numVHeads := int(q.LinearNumValueHeads)
headKDim := int(q.LinearKeyHeadDim)
headVDim := int(q.LinearValueHeadDim)
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
return qkvzSplitSpec{}, false
}
if numVHeads%numKHeads != 0 {
return qkvzSplitSpec{}, false
}
hidden := int(shape[1])
vPerHead := headVDim * (numVHeads / numKHeads)
qkvzDim := 2*headKDim + 2*vPerHead
expectedOut := qkvzDim * numKHeads
if int(shape[0]) != expectedOut {
return qkvzSplitSpec{}, false
}
return qkvzSplitSpec{
hidden: hidden,
headKDim: headKDim,
headVDim: headVDim,
numKHeads: numKHeads,
numVHeads: numVHeads,
qkvzDim: qkvzDim,
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
gateOut: headVDim * numVHeads,
}, true
}
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
spec, ok := q.qkvzSpec(t.Shape())
if !ok {
return nil, nil, false
}
qkvTensor := t.Clone()
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
gateTensor := t.Clone()
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
return &ggml.Tensor{
Name: qkvName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
WriterTo: qkvTensor,
}, &ggml.Tensor{
Name: gateName,
Kind: t.Kind(),
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
WriterTo: gateTensor,
}, true
}
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
return nil, err
}
offset := 0
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
if err != nil {
return nil, err
}
offset += spec.headKDim
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
offset += vPerHead
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
if err != nil {
return nil, err
}
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
return nil, err
}
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
return nil, err
}
var out tensor.Tensor
if extractGate {
out = zMat
} else {
out, err = tensor.Concat(1, qMat, kMat, vMat)
if err != nil {
return nil, err
}
}
out = tensor.Materialize(out)
out, err = tensor.Transpose(out, 1, 0)
if err != nil {
return nil, err
}
out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(out.(*tensor.Dense))
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
n, err := n.Add(ones)
if err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 0)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
// Full attention (self_attn)
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
"linear_attn.A_log", "ssm_a",
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
!strings.HasPrefix(st.name, "mm.") &&
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -71,6 +71,10 @@
{
"source": "/api",
"destination": "/api/introduction"
},
{
"source": "/integrations/clawdbot",
"destination": "/integrations/openclaw"
}
],
"navigation": {
@@ -102,8 +106,8 @@
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/clawdbot",
"/integrations/cline",
"/integrations/openclaw",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",

View File

@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -134,22 +134,12 @@ success
### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`
#### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com

View File

@@ -1,41 +1,43 @@
---
title: Clawdbot
title: OpenClaw
---
Clawdbot is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [Clawdbot](https://clawd.bot/)
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g clawdbot@latest
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
clawdbot onboard --install-daemon
openclaw onboard --install-daemon
```
<Note>Clawdbot requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch clawdbot
ollama launch openclaw
```
This configures Clawdbot to use Ollama and starts the gateway.
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To configure without launching:
```shell
ollama launch clawdbot --config
ollama launch openclaw --config
```
## Recommended Models

View File

@@ -9,7 +9,7 @@ OpenCode is an open-source AI coding assistant that runs in your terminal.
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install.sh | bash
curl -fsSL https://opencode.ai/install | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>

View File

@@ -201,7 +201,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},

View File

@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 4096,
"": 0,
"2048": 2048,
}

View File

@@ -268,8 +268,10 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"glmocr",
"lfm2",
}, kv.Architecture())
}
@@ -859,11 +861,13 @@ func (f GGML) FlashAttention() bool {
"bert",
"gemma3",
"glm4moelite",
"glmocr",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
"qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
}

6
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -27,7 +27,10 @@ require (
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -49,6 +52,7 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

34
go.sum
View File

@@ -152,6 +152,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -174,6 +176,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -204,12 +208,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -304,6 +335,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=

View File

@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
t.Fatalf("failed to pull model: %v", err)
}
req := api.GenerateRequest{
Model: "qwen3:0.6b",
Prompt: "Write a long story.",
Stream: &stream,
Logprobs: true,
Options: map[string]any{
"num_predict": 10,
"temperature": 0,
"seed": 123,
},
}
logprobCount := 0
var finalResponse api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
logprobCount += len(resp.Logprobs)
if resp.Done {
finalResponse = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %v", err)
}
if logprobCount != 10 {
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
}
}

View File

@@ -75,3 +75,10 @@ type Cache interface {
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
// CheckpointCache optionally supports restoring recurrent state to a prior
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
// The returned position is the number of tokens that can be kept (prefix length).
type CheckpointCache interface {
PrepareRestore(seq int, targetPos int32) (int32, bool)
}

View File

@@ -0,0 +1,276 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan <jmorganca@gmail.com>
Date: Tue, 3 Feb 2026 12:00:00 -0800
Subject: [PATCH] ggml: metal solve_tri
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
7 files changed, 177 insertions(+)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d13..83385c9ef 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_SOLVE_TRI);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_solve_tri_f32");
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a..8a9d17460 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 7b5ee968c..4e5acfbe5 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_SOLVE_TRI:
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32;
+ case GGML_OP_COUNT_EQUAL:
+ return has_simdgroup_reduction &&
+ op->src[0]->type == GGML_TYPE_I32 &&
+ op->src[1]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e9..cfdea9c07 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int64_t ne00;
int64_t ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 80864f303..4ac135603 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ const int64_t ncols = ne10;
+ const int64_t n_batches = (int64_t)ne02 * ne03;
+ const int64_t nr = n_batches * ncols;
+
+ int nth = 64;
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ if (nth < 1) {
+ nth = 1;
+ }
+
+ const int64_t n_tg = (nr + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b54452..a475183d3 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d33c16079..c37447a10 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const uint64_t ncols = (uint64_t) args.ne10;
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
+ const uint64_t nr = n_batches * ncols;
+
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
+ if (gid >= nr) {
+ return;
+ }
+
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
+ const uint64_t i02 = rem / ncols;
+ const uint64_t i01 = rem - i02 * ncols;
+
+ const uint64_t sa0 = args.nb00 / sizeof(float);
+ const uint64_t sa1 = args.nb01 / sizeof(float);
+ const uint64_t sa2 = args.nb02 / sizeof(float);
+ const uint64_t sa3 = args.nb03 / sizeof(float);
+
+ const uint64_t sb0 = args.nb10 / sizeof(float);
+ const uint64_t sb1 = args.nb11 / sizeof(float);
+ const uint64_t sb2 = args.nb12 / sizeof(float);
+ const uint64_t sb3 = args.nb13 / sizeof(float);
+
+ const uint64_t sx0 = args.nb0 / sizeof(float);
+ const uint64_t sx1 = args.nb1 / sizeof(float);
+ const uint64_t sx2 = args.nb2 / sizeof(float);
+ const uint64_t sx3 = args.nb3 / sizeof(float);
+
+ device const float * A = (device const float *) src0;
+ device const float * B = (device const float *) src1;
+ device float * X = (device float *) dst;
+
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
+
+ const uint64_t n = (uint64_t) args.ne11;
+
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
+ float sum = 0.0f;
+ for (uint64_t t = 0; t < i00; ++t) {
+ sum += A[A_base + i00 * sa1 + t * sa0] *
+ X[X_base + t * sx1 + i01 * sx0];
+ }
+
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
+ X[X_base + i00 * sx1 + i01 * sx0] =
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
+ }
+}
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -80,6 +81,7 @@ type LlamaServer interface {
GetPort() int
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
HasExited() bool
ContextLength() int
}
// llmServer is an instance of a runner hosting a single model
@@ -115,7 +117,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -141,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -154,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -210,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -260,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tok != nil,
modelPath,
gpuLibs,
status,
@@ -309,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tok != nil {
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1200,7 +1202,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do load request: %w", err)
slog.Error("do load request", "error", err)
return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
}
defer resp.Body.Close()
@@ -1772,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1807,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}
@@ -1901,6 +1904,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
return 0
}
func (s *llmServer) ContextLength() int {
return s.options.NumCtx
}
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
devices, err := ml.GetDevicesFromRunner(ctx, s)
if err != nil {

View File

@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
// Estimate input tokens for streaming (actual count not available until generation completes)
estimatedTokens := anthropic.EstimateInputTokens(req)
w := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
model: req.Model,
converter: anthropic.NewStreamConverter(messageID, req.Model),
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
}
if req.Stream {

View File

@@ -170,10 +170,12 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
GELU_ERF(ctx Context) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
SigmoidOut(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
@@ -206,6 +208,32 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Exp(ctx Context) Tensor
Neg(ctx Context) Tensor
// Clamp clamps values to [min, max] range
Clamp(ctx Context, min, max float32) Tensor
// Softplus computes ln(1 + exp(x))
Softplus(ctx Context) Tensor
// CumSum computes cumulative sum along dimension 0
CumSum(ctx Context) Tensor
// Diag creates a diagonal matrix from a 1D tensor
Diag(ctx Context) Tensor
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
Tri(ctx Context, triType int) Tensor
// Fill fills a tensor with a constant value (in-place)
Fill(ctx Context, value float32) Tensor
// Repeat4D repeats tensor to match target shape
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
// SolveTri solves a triangular system Ax = B
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

View File

@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
}
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
sched := C.ggml_backend_sched_new_ext(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
@@ -1581,6 +1588,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
@@ -1772,6 +1786,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
}
}
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
}
}
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
}
}
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
}
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {

View File

@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SOLVE_TRI);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_solve_tri_f32");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);

View File

@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:

View File

@@ -2385,6 +2385,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_solve_tri;
typedef struct {
int64_t ne00;
int64_t ne01;

View File

@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
} break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_solve_tri args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
const int64_t ncols = ne10;
const int64_t n_batches = (int64_t)ne02 * ne03;
const int64_t nr = n_batches * ncols;
int nth = 64;
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
if (nth < 1) {
nth = 1;
}
const int64_t n_tg = (nr + nth - 1) / nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);

View File

@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const uint64_t ncols = (uint64_t) args.ne10;
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
const uint64_t nr = n_batches * ncols;
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
if (gid >= nr) {
return;
}
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
const uint64_t i02 = rem / ncols;
const uint64_t i01 = rem - i02 * ncols;
const uint64_t sa0 = args.nb00 / sizeof(float);
const uint64_t sa1 = args.nb01 / sizeof(float);
const uint64_t sa2 = args.nb02 / sizeof(float);
const uint64_t sa3 = args.nb03 / sizeof(float);
const uint64_t sb0 = args.nb10 / sizeof(float);
const uint64_t sb1 = args.nb11 / sizeof(float);
const uint64_t sb2 = args.nb12 / sizeof(float);
const uint64_t sb3 = args.nb13 / sizeof(float);
const uint64_t sx0 = args.nb0 / sizeof(float);
const uint64_t sx1 = args.nb1 / sizeof(float);
const uint64_t sx2 = args.nb2 / sizeof(float);
const uint64_t sx3 = args.nb3 / sizeof(float);
device const float * A = (device const float *) src0;
device const float * B = (device const float *) src1;
device float * X = (device float *) dst;
const uint64_t A_base = i02 * sa2 + i03 * sa3;
const uint64_t B_base = i02 * sb2 + i03 * sb3;
const uint64_t X_base = i02 * sx2 + i03 * sx3;
const uint64_t n = (uint64_t) args.ne11;
for (uint64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (uint64_t t = 0; t < i00; ++t) {
sum += A[A_base + i00 * sa1 + t * sa0] *
X[X_base + t * sx1 + i01 * sx0];
}
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
X[X_base + i00 * sx1 + i01 * sx0] =
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
}
}
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,

View File

@@ -1,272 +0,0 @@
package model
import (
"cmp"
"iter"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
return sb.String(), nil
}

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var (
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -56,6 +56,18 @@ type fakeTensor struct {
Name string
}
// Stub methods to satisfy ml.Tensor interface
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: t,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.SentencePiece
tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizer.NewSentencePiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -0,0 +1,174 @@
package glmocr
import (
"image"
"log/slog"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
temporalPatchSize int
spatialMergeSize int
minPixels int
maxPixels int
factor int
imageMean [3]float32
imageStd [3]float32
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
// Read normalization values from config if available, otherwise use CLIP defaults
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 336)),
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
spatialMergeSize: spatialMergeSize,
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
factor: patchSize * spatialMergeSize,
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
}
}
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
temporalFactor := p.temporalPatchSize
numFrames := temporalFactor // single image
if height < factor || width < factor {
// Scale up small images
scale := float64(factor) / float64(min(height, width))
height = int(math.Ceil(float64(height) * scale))
width = int(math.Ceil(float64(width) * scale))
}
if temporalFactor <= 0 {
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
temporalFactor = 1
}
if numFrames < temporalFactor {
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
numFrames = temporalFactor
}
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
if tBar*hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if tBar*hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
img = imageproc.Composite(img)
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
// We keep [C, H, W] for patch extraction
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
// Calculate grid dimensions (after Conv2D patching)
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // Single image
ImageHeight: resizedHeight,
ImageWidth: resizedWidth,
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, err
}
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := 3
patchSize := p.patchSize
mergeSize := p.spatialMergeSize
temporalPatchSize := p.temporalPatchSize
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
srcIdx := c*height*width + y*width + x
dstIdx := channelOffset + (py * patchSize) + px
result[dstIdx] = pixels[srcIdx]
}
}
if temporalPatchSize > 1 {
frameSize := patchSize * patchSize
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[channelOffset:channelOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@@ -0,0 +1,236 @@
package glmocr
import (
"bytes"
"errors"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
PatchMerger *PatchMerger `gguf:"mm"`
ImageProcessor
imageTokenID int32
imageStartTokenID int32
imageEndTokenID int32
}
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
m := &Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: allEOS,
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
imageTokenID: int32(c.Uint("image_token_id", 59280)),
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Blocks) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
// Create pixel values tensor from flattened patches
// Shape: [patchDim, numPatches]
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
// Forward through vision encoder
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
// Forward through downsample (patch merger)
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
return nil, errors.New("glmocr: missing vision downsample weights")
}
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
if m.PatchMerger == nil {
return nil, errors.New("glmocr: missing patch merger weights")
}
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
// Reset position cache
m.TextModel.positionCache = m.TextModel.positionCache[:0]
m.TextModel.ropeDelta = 0
pos := int32(0)
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
continue
}
// Get grid info for position calculation
grid := inp.Multimodal[0].Data.(*Grid)
mergedH := grid.Height / m.VisionModel.spatialMergeSize
mergedW := grid.Width / m.VisionModel.spatialMergeSize
// Add image start token
result = append(result, &input.Input{Token: m.imageStartTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
// Add image tokens with multimodal data
// All image tokens share the same base position for temporal dimension
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
basePos := pos
sameBatch := tokensPerGrid - 1
if sameBatch < 0 {
sameBatch = 0
}
result = append(result, &input.Input{
Token: m.imageTokenID,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: sameBatch,
})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
// Add placeholder tokens for remaining positions
// All image tokens use the same base position (temporal stays constant)
for range tokensPerGrid - 1 {
result = append(result, &input.Input{Token: m.imageTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
}
// Advance position by max(mergedH, mergedW) after image tokens
pos = basePos + int32(max(mergedH, mergedW))
// Add image end token
result = append(result, &input.Input{Token: m.imageEndTokenID})
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
pos++
}
// Compute rope delta for continuation after the prefill segment:
// delta = (max_position_id + 1) - sequence_length
if len(m.TextModel.positionCache) > 0 {
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
ctx.Forward(hiddenStates)
// Build position slices for M-RoPE
positionSlice := func() [][]int32 {
s := [][]int32{
make([]int32, len(batch.Positions)), // temporal
make([]int32, len(batch.Positions)), // height
make([]int32, len(batch.Positions)), // width
make([]int32, len(batch.Positions)), // unused (zeros)
}
for i, position := range batch.Positions {
// Translate through position cache or continue sequence
if position < int32(len(m.TextModel.positionCache)) {
position = m.TextModel.positionCache[position]
} else if len(m.TextModel.positionCache) > 0 {
// Continue sequence after cached positions using ropeDelta
position = position + m.TextModel.ropeDelta
}
s[0][i] = position
s[1][i] = position
s[2][i] = position
}
return s
}()
// Inject vision embeddings and adjust positions for image tokens
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
w := grid.Width / m.VisionModel.spatialMergeSize
for i := range img.Dim(1) {
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
// Process through transformer layers
for i, layer := range m.TextModel.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glmocr", New)
}

View File

@@ -0,0 +1,190 @@
package glmocr
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type TextModelOptions struct {
hiddenSize int
numHeads int
numKVHeads int
headDim int
rotaryDim int
intermediateSize int
eps float32
ropeBase float32
mropeSections []int
}
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
// With 4 sections for [temporal, height, width, unused]
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Separate Q, K, V projections
q := sa.Query.Forward(ctx, hiddenStates)
k := sa.Key.Forward(ctx, hiddenStates)
v := sa.Value.Forward(ctx, hiddenStates)
// Reshape for GQA
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
// Apply M-RoPE (multi-resolution rotary position embeddings)
q = opts.applyMRoPE(ctx, q, positions)
k = opts.applyMRoPE(ctx, k, positions)
// Scaled dot-product attention with KV cache
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
return sa.Output.Forward(ctx, kqv)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type TextDecoderLayer struct {
// Input layernorm (before attention)
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
// Post self-attention layernorm (after attention, before residual add)
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
// FFN input layernorm (after first residual, before MLP)
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
// Post MLP layernorm (after MLP, before residual add)
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
}
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
// Attention block
residual := hiddenStates
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
// Prune to output positions in final layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP block
residual = hiddenStates
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextDecoderLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextModelOptions
// positionCache stores the M-RoPE position for each token in the sequence.
// This is needed because image tokens share the same base position but have
// different height/width offsets, and the end token position depends on the
// image grid dimensions.
positionCache []int32
ropeDelta int32
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// Clear position cache when KV cache shifts
m.positionCache = nil
m.ropeDelta = 0
return m.applyMRoPE(ctx, key, shift), nil
}
func newTextModel(c fs.Config) *TextModel {
hiddenSize := int(c.Uint("embedding_length", 1536))
numHeads := int(c.Uint("attention.head_count", 16))
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
intermediateSize := int(c.Uint("feed_forward_length", 4608))
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
ropeBase := c.Float("rope.freq_base", 10000)
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
if ropeDim <= 0 {
ropeDim = headDim
}
mropeSections := c.Ints("rope.mrope_section")
var sectionInts []int
if len(mropeSections) > 0 {
sectionInts = make([]int, len(mropeSections))
for i, section := range mropeSections {
sectionInts[i] = int(section)
}
} else {
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
// For rotaryDim=64 this yields [8, 12, 12].
total := ropeDim / 2
if total <= 0 {
total = 32
}
s0 := total * 2 / 8
s1 := total * 3 / 8
s2 := total - s0 - s1
sectionInts = []int{s0, s1, s2}
}
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
rotaryDim := ropeDim
return &TextModel{
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
TextModelOptions: &TextModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
rotaryDim: rotaryDim,
intermediateSize: intermediateSize,
eps: eps,
ropeBase: ropeBase,
mropeSections: sectionInts,
},
}
}

View File

@@ -0,0 +1,355 @@
package glmocr
import (
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
type Grid struct {
Height int // Number of patches in height direction
Width int // Number of patches in width direction
Temporal int
ImageHeight int // Full image height in pixels
ImageWidth int // Full image width in pixels
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
numChannels int
patchSize int
temporalPatchSize int
imageSize int
spatialMergeSize int
outHiddenSize int
intermediateSize int
eps float32
}
type VisionPatchEmbed struct {
Proj *nn.Conv2D `gguf:"patch_embd_0"`
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
Bias ml.Tensor `gguf:"patch_embd.bias"`
}
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
_ = grid // patches are already in merge-block order
// pixelValues shape: [patchDim, numPatches]
numPatches := pixelValues.Shape()[1]
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Slice temporal frames for Conv2D (simulate Conv3D)
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
hiddenStates = hiddenStates.Add(ctx, out1)
}
// Flatten to [hidden_size, num_patches]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
if pe.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
}
return hiddenStates
}
type VisionSelfAttention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
// Combined QKV projection: [3*hidden_size, batch_size]
qkv := sa.QKV.Forward(ctx, hiddenStates)
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
// ChunkSections returns views - must make contiguous before further operations
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
q := chunks[0].Contiguous(ctx)
k := chunks[1].Contiguous(ctx)
v := chunks[2].Contiguous(ctx)
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
// Apply Q-norm and K-norm after head reshape
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
q = sa.QNorm.Forward(ctx, q, opts.eps)
k = sa.KNorm.Forward(ctx, k, opts.eps)
// Apply rotary position embeddings with vision-style 2D positions.
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
// We provide H/W sections and leave the remaining sections empty.
ropeFreqBase := float32(10000.0)
section := opts.headDim / 4
if section <= 0 {
section = 1
}
sections := []int{section, section, 0, 0}
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Try flash attention first (ScaledDotProductAttention), fall back to manual
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
slog.Warn("glmocr: vision attention falling back to manual attention",
"batchSize", batchSize, "numHeads", opts.numHeads,
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
// Manual attention fallback
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Attention scores
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, scale)
kq = kq.Softmax(ctx)
// Attention output: v @ kq (note: v first)
kqv := v.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
// SwiGLU: down(silu(gate(x)) * up(x))
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, gate)
}
type VisionBlock struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Pre-norm architecture
residual := hiddenStates
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionDownsample struct {
*nn.Conv2D
}
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
// Apply spatial downsampling via Conv2D
// Input: [hidden_size, num_patches] where patches are in merge-block order
if d.Conv2D == nil || d.Weight == nil {
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
return hiddenStates // Return input unchanged as fallback
}
merge := opts.spatialMergeSize
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
// ggml semantics: result.ne[perm[i]] = input.ne[i]
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
// Step 3: Apply Conv2D without bias (bias added after reshape)
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
s0, s1 := merge, merge
p0, p1 := 0, 0
d0, d1 := 1, 1
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
// Step 5: Add bias after reshape
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
if d.Bias != nil {
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
}
return hiddenStates
}
type PatchMerger struct {
// GGUF tags align with mm.* keys used by the model
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
}
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Linear projection
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
// Post-projection layer norm + GELU ERF
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.GELU_ERF(ctx)
// Force a copy to avoid in-place mutation issues with GELU_ERF
hiddenStates = hiddenStates.Contiguous(ctx)
// SwiGLU MLP: down(silu(gate(x)) * up(x))
gateOut := m.GateProj.Forward(ctx, hiddenStates)
upOut := m.UpProj.Forward(ctx, hiddenStates)
gate := gateOut.SILU(ctx, upOut)
return m.DownProj.Forward(ctx, gate)
}
type VisionModel struct {
PatchEmbed *VisionPatchEmbed
Blocks []VisionBlock `gguf:"blk"`
PostLN *nn.RMSNorm `gguf:"post_ln"`
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings from flattened patches
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
// Create position IDs for RoPE (spatial grid)
// Patches are already in merge-block order from preprocessing
positions := m.createPositions(ctx, grid)
// Process through vision blocks
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
}
// Post-layernorm
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
// Note: Downsample is now applied separately in Model.EncodeMultimodal
// so mm.patch_merger remains a distinct module
return hiddenStates
}
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
// Create spatial position IDs for vision RoPE
// Position layout: [height, width, height, width] - 4 sections for mrope
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
// This follows the GLM-OCR rot_pos_emb layout
numPatches := grid.Height * grid.Width
mergeRatio := m.spatialMergeSize
// Build position arrays in merge-block order
// Each merge_ratio x merge_ratio block of patches is grouped together
hpos := make([]int32, numPatches)
wpos := make([]int32, numPatches)
ptr := 0
for y := 0; y < grid.Height; y += mergeRatio {
for x := 0; x < grid.Width; x += mergeRatio {
for dy := range mergeRatio {
for dx := range mergeRatio {
hpos[ptr] = int32(y + dy)
wpos[ptr] = int32(x + dx)
ptr++
}
}
}
}
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
// keep remaining sections zeroed to match its conventions.
zeros := make([]int32, numPatches)
s := [][]int32{
hpos, // Section 0: height
wpos, // Section 1: width
zeros, // Section 2: unused
zeros, // Section 3: unused
}
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
}
func newVisionModel(c fs.Config) *VisionModel {
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
patchSize := int(c.Uint("vision.patch_size", 14))
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
imageSize := int(c.Uint("vision.image_size", 336))
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
return &VisionModel{
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
numChannels: numChannels,
patchSize: patchSize,
temporalPatchSize: temporalPatchSize,
imageSize: imageSize,
spatialMergeSize: spatialMergeSize,
outHiddenSize: outHiddenSize,
intermediateSize: intermediateSize,
eps: eps,
},
}
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizer.Tokenizer
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
@@ -19,5 +20,6 @@ import (
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
_ "github.com/ollama/ollama/model/models/qwen3next"
_ "github.com/ollama/ollama/model/models/qwen3vl"
)

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
processor := model.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -0,0 +1,103 @@
package qwen3next
import (
"errors"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the attention layer requirements.
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
// Key differences from standard attention:
// - Q projection outputs 2x size (Q + gate interleaved)
// - Both Q and K have RMSNorm
// - Output is gated: attn * sigmoid(gate)
type FullAttention struct {
Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
// Use Dim() instead of Shape() for consistent behavior during graph construction
hiddenDim := hiddenStates.Dim(0)
batchSize := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 0 {
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
if batchSize != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
batchSize = seqTokens * seqs
} else if batchSize != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
}
}
headDim := opts.headDim()
numHeads := opts.numHeads
// Q projection outputs query + gate interleaved
qFull := sa.Query.Forward(ctx, hiddenStates)
// Reshape to [headDim * 2, numHeads, batchSize]
qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
// Split Q and gate along dimension 0
// Q: first headDim elements, gate: second headDim elements
query := qFull.Slice(ctx, 0, 0, headDim, 1)
gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
// Make query contiguous for further operations
query = query.Contiguous(ctx, headDim, numHeads, batchSize)
// K and V projections
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
// Derive numKVHeads from tensor dimensions (per-layer value)
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Apply QK normalization
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
// Apply RoPE
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Standard attention computation
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Apply sigmoid gate
// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
gateSigmoid := gate.Sigmoid(ctx)
attention = attention.Mul(ctx, gateSigmoid)
return sa.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,596 @@
package qwen3next
import (
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for full attention layers
// - per-sequence conv state for linear attention layers
// - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
// Conv state dimensions
convDim int // convKernelSize - 1
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
// Delta state dimensions
deltaStateSize int // headVDim * headVDim * numVHeads
// slot mapping for recurrent state (copy-on-write)
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
// per-layer delta state buffers (allocated lazily)
deltaCtxs map[int]ml.Context
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
// recurrent checkpoints (per slot)
checkpointCount int
checkpointMinPos int32
checkpointInterval int32
checkpointCtxSize int
checkpoints map[int]*slotCheckpointStore
pendingRestore map[int]checkpointRestore
curCheckpointPos []int32
curCheckpointSlots map[int]int
reserveCheckpoints bool
checkpointConvCtxs map[int]ml.Context
checkpointDeltaCtxs map[int]ml.Context
checkpointReserved map[int]struct{}
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
writableError error
}
func NewHybridCache(
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
convDim, convChannels, deltaStateSize int,
) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
convDim: convDim,
convChannels: convChannels,
deltaStateSize: deltaStateSize,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
deltaCtxs: make(map[int]ml.Context),
deltaStates: make(map[int]ml.Tensor),
checkpointCount: checkpointCountDefault,
checkpointMinPos: checkpointMinPosDefault,
checkpointInterval: checkpointIntervalDefault,
checkpoints: make(map[int]*slotCheckpointStore),
pendingRestore: make(map[int]checkpointRestore),
curCheckpointSlots: make(map[int]int),
checkpointConvCtxs: make(map[int]ml.Context),
checkpointDeltaCtxs: make(map[int]ml.Context),
checkpointReserved: make(map[int]struct{}),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
c.checkpoints = make(map[int]*slotCheckpointStore)
c.pendingRestore = make(map[int]checkpointRestore)
c.curCheckpointPos = c.curCheckpointPos[:0]
c.curCheckpointSlots = make(map[int]int)
c.checkpointReserved = make(map[int]struct{})
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
if c.checkpointCtxSize < 8 {
c.checkpointCtxSize = 8
}
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
for _, ctx := range c.deltaCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointConvCtxs {
ctx.Close()
}
for _, ctx := range c.checkpointDeltaCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for recurrent layers
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
c.reserveCheckpoints = true
c.planCheckpoints(batch)
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero state for newly allocated slots
if len(newSlots) > 0 {
c.zeroSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
c.reserveCheckpoints = false
c.planCheckpoints(batch)
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroSlots zeros the recurrent state for the given slots across all layers.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 {
return
}
inputCtx := ctx.Input()
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Zero conv states
if len(c.convStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// Zero delta states
if len(c.deltaStates) > 0 {
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
for _, buf := range c.deltaStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
}
// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
for _, buf := range c.convStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
for _, buf := range c.deltaStates {
rows := buf.Rows(ctx, src)
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
}
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// Copy-on-write for recurrent state
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
if c.validSlot(dstSlot) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
return
}
if c.validSlot(srcSlot) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
if !c.kv.CanResume(seq, pos) {
return false
}
if pos == 0 {
return true
}
return c.hasCheckpoint(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if beginIndex > 0 && endIndex != math.MaxInt32 {
return kvcache.ErrNotSupported
}
if beginIndex > 0 {
restore, ok := c.pendingRestore[seq]
if !ok || restore.pos+1 != beginIndex {
return kvcache.ErrNotSupported
}
if !c.restoreComplete(restore) {
return kvcache.ErrNotSupported
}
// If the recurrent slot is shared, detach it before applying a restore.
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
newSlot, err := c.allocSlot()
if err != nil {
return err
}
ctx := c.backend.NewContext()
c.copyRecurrentState(ctx, slot, newSlot)
c.copyCheckpoints(ctx, slot, newSlot)
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
ctx.Compute()
}
ctx.Close()
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
restore.slot = newSlot
c.pendingRestore[seq] = restore
}
}
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
if beginIndex > 0 {
restore := c.pendingRestore[seq]
delete(c.pendingRestore, seq)
return c.applyCheckpointRestore(restore)
}
// Removal invalidates recurrent state
slot, ok := c.slotForSeq[seq]
delete(c.pendingRestore, seq)
if !ok {
return nil
}
if !c.validSlot(slot) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.clearCheckpoints(slot)
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) validSlot(slot int) bool {
return slot >= 0 && slot < len(c.refCount)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *HybridCache) contiguousSlots() (int, bool) {
if len(c.curSlots) == 0 {
return 0, false
}
start := c.curSlots[0]
for i, s := range c.curSlots {
if s != start+i {
return 0, false
}
}
return start, true
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
c.convStates[layer] = buf
return buf
}
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.deltaStates[layer]; ok {
return buf
}
if _, ok := c.deltaCtxs[layer]; !ok {
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
// Recurrent delta state must stay in F32.
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
c.deltaStates[layer] = buf
return buf
}
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
}
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureConvCheckpoint(ctx, layer, srcF32)
}
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
c.ensureWritableOnce(ctx)
if c.writableError != nil {
return nil, c.writableError
}
buf := c.deltaBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}
// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.deltaBuffer(ctx, layer)
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
srcF32 := src.Cast(ctx, ml.DTypeF32)
if start, ok := c.contiguousSlots(); ok {
// Fast path: contiguous slots allow a single view + copy
offset := start * buf.Stride(1)
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
ctx.Forward(srcF32.Copy(ctx, view))
} else {
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
c.captureDeltaCheckpoint(ctx, layer, srcF32)
}
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,498 @@
package qwen3next
import (
"log/slog"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
const (
checkpointCountDefault = 32
checkpointMinPosDefault = int32(16)
checkpointIntervalDefault = int32(1280)
)
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.
type checkpointEntry struct {
pos int32
conv map[int]ml.Tensor
delta map[int]ml.Tensor
}
type slotCheckpointStore struct {
entries []checkpointEntry
size int
next int
lastPos int32
}
type checkpointRestore struct {
slot int
idx int
pos int32
}
func newSlotCheckpointStore(n int) *slotCheckpointStore {
entries := make([]checkpointEntry, n)
for i := range entries {
entries[i].pos = -1
}
return &slotCheckpointStore{
entries: entries,
lastPos: -1,
}
}
func (s *slotCheckpointStore) reset() {
s.size = 0
s.next = 0
s.lastPos = -1
for i := range s.entries {
s.entries[i].pos = -1
}
}
func (s *slotCheckpointStore) record(pos int32) int {
if len(s.entries) == 0 {
return -1
}
idx := s.next
s.next = (s.next + 1) % len(s.entries)
if s.size < len(s.entries) {
s.size++
}
s.entries[idx].pos = pos
s.lastPos = pos
return idx
}
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
bestIdx := -1
bestPos := int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 || pos >= targetPos {
continue
}
if pos > bestPos {
bestPos = pos
bestIdx = i
}
}
if bestIdx < 0 {
return -1, -1, false
}
return bestIdx, bestPos, true
}
func (s *slotCheckpointStore) pruneAfter(pos int32) {
if len(s.entries) == 0 {
s.size = 0
s.next = 0
s.lastPos = -1
return
}
size := 0
next := -1
minPos := int32(math.MaxInt32)
minIdx := 0
for i := range s.entries {
if s.entries[i].pos > pos {
s.entries[i].pos = -1
}
if s.entries[i].pos >= 0 {
size++
if s.entries[i].pos < minPos {
minPos = s.entries[i].pos
minIdx = i
}
} else if next == -1 {
next = i
}
}
s.size = size
if size == 0 {
s.next = 0
s.lastPos = -1
return
}
if next != -1 {
s.next = next
} else {
// Full ring: overwrite the oldest checkpoint next.
s.next = minIdx
}
s.lastPos = pos
}
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
minPos = int32(math.MaxInt32)
maxPos = int32(-1)
for i := range s.entries {
pos := s.entries[i].pos
if pos < 0 {
continue
}
size++
if pos < minPos {
minPos = pos
}
if pos > maxPos {
maxPos = pos
}
}
if size == 0 {
minPos = -1
maxPos = -1
}
return size, minPos, maxPos, s.lastPos
}
func (c *HybridCache) planCheckpoints(batch input.Batch) {
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
c.curCheckpointPos = c.curCheckpointPos[:0]
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
return
}
if cap(c.curCheckpointPos) < len(c.curSeqs) {
c.curCheckpointPos = make([]int32, len(c.curSeqs))
} else {
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
}
for i := range c.curCheckpointPos {
c.curCheckpointPos[i] = -1
}
for k := range c.curCheckpointSlots {
delete(c.curCheckpointSlots, k)
}
posMax := make(map[int]int32, len(c.curSeqs))
for i, seq := range batch.Sequences {
pos := batch.Positions[i]
if cur, ok := posMax[seq]; !ok || pos > cur {
posMax[seq] = pos
}
}
for i, seq := range c.curSeqs {
pos, ok := posMax[seq]
if !ok {
continue
}
if pos < c.checkpointMinPos {
continue
}
slot := c.curSlots[i]
store := c.checkpointStore(slot)
lastPos := store.lastPos
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
c.curCheckpointPos[i] = pos
}
}
}
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
store, ok := c.checkpoints[slot]
if ok {
return store
}
store = newSlotCheckpointStore(c.checkpointCount)
c.checkpoints[slot] = store
return store
}
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
if c.checkpointCount == 0 {
return -1
}
if idx, ok := c.curCheckpointSlots[slot]; ok {
return idx
}
store := c.checkpointStore(slot)
idx := store.record(pos)
if idx >= 0 {
c.curCheckpointSlots[slot] = idx
}
return idx
}
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
if pos <= 0 {
return false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return false
}
store, ok := c.checkpoints[slot]
if !ok {
return false
}
_, _, ok = store.bestIndex(pos)
return ok
}
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
if targetPos <= 0 {
return 0, false
}
slot, ok := c.slotForSeq[seq]
if !ok {
return 0, false
}
store, ok := c.checkpoints[slot]
if !ok {
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
return 0, false
}
idx, pos, ok := store.bestIndex(targetPos)
if !ok {
size, minPos, maxPos, lastPos := store.window()
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
"min", minPos, "max", maxPos, "last", lastPos)
return 0, false
}
c.pendingRestore[seq] = checkpointRestore{
slot: slot,
idx: idx,
pos: pos,
}
return pos + 1, true
}
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
entry, ok := c.restoreEntry(restore)
if !ok {
return kvcache.ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
for layer, src := range entry.conv {
buf := c.convBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
for layer, src := range entry.delta {
buf := c.deltaBuffer(ctx, layer)
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
}
if len(entry.conv) > 0 || len(entry.delta) > 0 {
ctx.Compute()
}
store := c.checkpoints[restore.slot]
store.pruneAfter(restore.pos)
return nil
}
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
_, ok := c.restoreEntry(restore)
return ok
}
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
store, ok := c.checkpoints[restore.slot]
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
return nil, false
}
entry := &store.entries[restore.idx]
if entry.pos < 0 {
return nil, false
}
if !c.entryComplete(entry) {
return nil, false
}
return entry, true
}
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
for layer := range c.convStates {
if entry.conv == nil || entry.conv[layer] == nil {
return false
}
}
for layer := range c.deltaStates {
if entry.delta == nil || entry.delta[layer] == nil {
return false
}
}
return true
}
func (c *HybridCache) clearCheckpoints(slot int) {
if store, ok := c.checkpoints[slot]; ok {
store.reset()
}
}
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
if c.checkpointCount == 0 {
return
}
srcStore, ok := c.checkpoints[srcSlot]
if !ok || srcStore.size == 0 {
return
}
dstStore := c.checkpointStore(dstSlot)
dstStore.size = srcStore.size
dstStore.next = srcStore.next
dstStore.lastPos = srcStore.lastPos
for i := range srcStore.entries {
srcEntry := &srcStore.entries[i]
dstEntry := &dstStore.entries[i]
dstEntry.pos = srcEntry.pos
if srcEntry.conv != nil {
if dstEntry.conv == nil {
dstEntry.conv = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.conv {
dst := c.ensureCheckpointConv(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
if srcEntry.delta != nil {
if dstEntry.delta == nil {
dstEntry.delta = make(map[int]ml.Tensor)
}
for layer, src := range srcEntry.delta {
dst := c.ensureCheckpointDelta(layer, dstEntry)
ctx.Forward(src.Copy(ctx, dst))
}
}
}
}
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointConv(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointConv(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
if c.checkpointCount == 0 {
return
}
if c.reserveCheckpoints {
c.reserveCheckpointDelta(layer)
return
}
if len(c.curCheckpointPos) == 0 {
return
}
for i, pos := range c.curCheckpointPos {
if pos < 0 {
continue
}
slot := c.curSlots[i]
idx := c.checkpointIndexForSlot(slot, pos)
if idx < 0 {
continue
}
entry := &c.checkpoints[slot].entries[idx]
dst := c.ensureCheckpointDelta(layer, entry)
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
ctx.Forward(seqSlice.Copy(ctx, dst))
}
}
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
if entry.conv == nil {
entry.conv = make(map[int]ml.Tensor)
}
if t, ok := entry.conv[layer]; ok {
return t
}
ctx, ok := c.checkpointConvCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointConvCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
entry.conv[layer] = t
return t
}
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
if entry.delta == nil {
entry.delta = make(map[int]ml.Tensor)
}
if t, ok := entry.delta[layer]; ok {
return t
}
ctx, ok := c.checkpointDeltaCtxs[layer]
if !ok {
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
c.checkpointDeltaCtxs[layer] = ctx
}
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
entry.delta[layer] = t
return t
}
func (c *HybridCache) reserveCheckpointConv(layer int) {
key := checkpointReserveKey(layer, 0)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointConv(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func (c *HybridCache) reserveCheckpointDelta(layer int) {
key := checkpointReserveKey(layer, 1)
if _, ok := c.checkpointReserved[key]; ok {
return
}
for slot := range c.maxSequences {
store := c.checkpointStore(slot)
for i := range store.entries {
entry := &store.entries[i]
_ = c.ensureCheckpointDelta(layer, entry)
}
}
c.checkpointReserved[key] = struct{}{}
}
func checkpointReserveKey(layer int, kind int) int {
return layer*2 + kind
}

View File

@@ -0,0 +1,300 @@
package qwen3next
import (
"errors"
"math"
"os"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
func newTestBackend(tb testing.TB) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
if err != nil {
tb.Fatal(err)
}
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
_ = f.Close()
tb.Fatal(err)
}
if err := f.Close(); err != nil {
tb.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
if err != nil {
tb.Fatal(err)
}
tb.Cleanup(func() {
b.Close()
})
return b
}
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
store := newSlotCheckpointStore(2)
store.record(10)
store.record(20)
_, pos, ok := store.bestIndex(15)
if !ok || pos != 10 {
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
}
store.record(30) // overwrite oldest (10)
if _, _, ok := store.bestIndex(15); ok {
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
}
_, pos, ok = store.bestIndex(40)
if !ok || pos != 30 {
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCachePrepareRestore(t *testing.T) {
cache := NewHybridCache(nil, 1, 1, 1)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
store := cache.checkpointStore(0)
store.record(5)
store.record(9)
store.record(15)
restorePos, ok := cache.PrepareRestore(1, 12)
if !ok {
t.Fatalf("expected restore ok")
}
if restorePos != 10 {
t.Fatalf("expected restorePos 10, got %d", restorePos)
}
rest, ok := cache.pendingRestore[1]
if !ok {
t.Fatalf("expected pending restore entry")
}
if rest.pos != 9 {
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
}
}
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
store.pruneAfter(20)
if store.lastPos != 20 {
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
}
_, pos, ok := store.bestIndex(25)
if !ok || pos != 20 {
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
}
_, pos, ok = store.bestIndex(35)
if !ok || pos != 20 {
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
}
}
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
backend := newTestBackend(t)
cache := NewHybridCache(nil, 1, 2, 2)
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
cache.slotForSeq[1] = 0
cache.slotForSeq[2] = 0
cache.refCount[0] = 2
cache.refCount[1] = 0
cache.freeSlots = []int{1}
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
t.Fatalf("Remove failed: %v", err)
}
if cache.slotForSeq[1] == cache.slotForSeq[2] {
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
}
if cache.slotForSeq[1] != 1 {
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
}
if cache.slotForSeq[2] != 0 {
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
}
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
}
if _, ok := cache.pendingRestore[1]; ok {
t.Fatalf("expected pending restore to be cleared")
}
}
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
store := cache.checkpointStore(0)
idx := store.record(9)
entry := &store.entries[idx]
// Only set conv checkpoint, not delta - making it incomplete
entry.conv = map[int]ml.Tensor{0: nil}
// entry.delta is not set, so checkpoint is incomplete
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
err := cache.Remove(1, 10, math.MaxInt32)
if !errors.Is(err, kvcache.ErrNotSupported) {
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
}
}
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
cache := NewHybridCache(nil, 1, 2, 2)
cache.checkpointCount = 3
cache.checkpoints = make(map[int]*slotCheckpointStore)
cache.pendingRestore = make(map[int]checkpointRestore)
cache.slotForSeq[1] = 0
cache.refCount = []int{1}
cache.freeSlots = nil
// Don't set convStates/deltaStates - with no layers to check,
// entryComplete will return true as long as entry.pos >= 0
store := cache.checkpointStore(0)
idx := store.record(9)
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
// Test that restoreComplete returns true when no layers need checkpoints
restore := cache.pendingRestore[1]
if !cache.restoreComplete(restore) {
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
}
}
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
// Test that ring buffer wrap-around reuses entries without clearing maps.
store := newSlotCheckpointStore(3)
// Fill the buffer
store.record(10)
store.record(20)
store.record(30)
// Create fake tensor data in the first entry's maps
store.entries[0].conv = make(map[int]ml.Tensor)
store.entries[0].conv[0] = nil // Simulated tensor reference
store.entries[0].delta = make(map[int]ml.Tensor)
store.entries[0].delta[0] = nil // Simulated tensor reference
// Record another entry, which should wrap around and overwrite entry 0
store.record(40)
// Verify the maps are still present (we reuse tensors)
if store.entries[0].conv == nil {
t.Fatalf("expected conv map to be preserved on reuse")
}
if store.entries[0].delta == nil {
t.Fatalf("expected delta map to be preserved on reuse")
}
// Verify the new position was recorded
if store.entries[0].pos != 40 {
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
}
}
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
// Test behavior when buffer is exactly at capacity
store := newSlotCheckpointStore(2)
idx1 := store.record(10)
idx2 := store.record(20)
if idx1 != 0 || idx2 != 1 {
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
}
if store.size != 2 {
t.Fatalf("expected size 2, got %d", store.size)
}
// Verify both checkpoints are accessible
_, pos1, ok1 := store.bestIndex(15)
_, pos2, ok2 := store.bestIndex(25)
if !ok1 || pos1 != 10 {
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
}
if !ok2 || pos2 != 20 {
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
}
}
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
// Test behavior with zero-size buffer
store := newSlotCheckpointStore(0)
idx := store.record(10)
if idx != -1 {
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
}
_, _, ok := store.bestIndex(15)
if ok {
t.Fatalf("expected no checkpoint for empty buffer")
}
}
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
// Test pruning that removes all checkpoints
store := newSlotCheckpointStore(3)
store.record(10)
store.record(20)
store.record(30)
// Prune everything by setting threshold below all positions
store.pruneAfter(5)
if store.size != 0 {
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
}
// When all checkpoints are pruned, lastPos is reset to -1
if store.lastPos != -1 {
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
}
_, _, ok := store.bestIndex(100)
if ok {
t.Fatalf("expected no checkpoint after pruning all")
}
}

View File

@@ -0,0 +1,472 @@
package qwen3next
import (
"errors"
"log/slog"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
const chunkSize = 64
// TriType constants for triangular matrix operations
const (
TriTypeUpperDiag = 0
TriTypeUpper = 1
TriTypeLowerDiag = 2
TriTypeLower = 3
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Masks holds pre-computed mask tensors for chunked attention
type Masks struct {
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
Diag ml.Tensor // causal + identity
}
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
// It implements the Operator interface directly.
type GatedDeltaNet struct {
// Optimized path: pre-split QKV and gate
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
// Layer index for cache access (set during model construction)
Layer int
}
// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
ones = ones.Fill(ctx, 1.0)
causalMask := ones.Tri(ctx, TriTypeLower)
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
onesVec = onesVec.Fill(ctx, 1.0)
identity := onesVec.Diag(ctx)
diagMask := causalMask.Add(ctx, identity)
return &Masks{
Causal: causalMask,
Identity: identity,
Diag: diagMask,
}
}
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := gdn.Layer
nSeqTokens := hiddenStates.Dim(1)
nSeqs := hiddenStates.Dim(2)
if cache != nil && cache.IsSupportedForBatch() {
seqTokens := cache.seqTokens()
seqs := cache.numSeqs()
if seqTokens > 0 && seqs > 0 {
if nSeqs > 1 {
if nSeqTokens != seqTokens || nSeqs != seqs {
return nil, ErrUnsupportedBatchLayout
}
} else {
if nSeqTokens != seqTokens*seqs {
return nil, ErrUnsupportedBatchLayout
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
nSeqTokens = seqTokens
nSeqs = seqs
}
}
}
headKDim := opts.ssmDState
numKHeads := opts.ssmNGroup
numVHeads := opts.ssmDtRank
headVDim := opts.ssmDInner / numVHeads
convKernelSize := opts.convKernelSize
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
}
// Optimized path: pre-split QKV and gate
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
baNewDim := 2 * numVHeads / numKHeads
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
// Split beta and alpha
betaSize := numVHeads / numKHeads
alphaSize := numVHeads / numKHeads
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Reshape to merge head dimensions
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
alphaSoftplus := alphaBiased.Softplus(ctx)
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
// Log this - if it happens, short-term context will be lost
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
}
// Reshape conv states
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
// Concatenate with input for convolution
convInput := convStates.Concat(ctx, qkvMixed, 0)
// Save new conv state (last convKernelSize-1 tokens)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution (kernel must be F32 for Metal)
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
convOutput = convOutput.SILU(ctx)
// Reshape for extraction
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
// Extract convolved Q, K, V
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
// Reshape to 4D
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
// Get delta state from cache
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
if err != nil {
// Log this - if it happens frequently, context will degrade
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
}
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
// Repeat interleave Q and K if numKHeads != numVHeads
if numKHeads != numVHeads {
repeatFactor := numVHeads / numKHeads
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
}
// Choose computation mode based on sequence length
var attnOut ml.Tensor
if nSeqTokens == 1 {
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
} else {
// Use pre-computed masks from opts (created once in Model.Forward)
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
}
// Apply gated normalization
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
zSilu := z2D.SILU(ctx)
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
// Reshape for output projection
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
out := gdn.SSMOut.Forward(ctx, finalOutput)
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}
// deltaNetAutoregressive implements single-token state update.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Reshape gate and beta for broadcasting
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
// Apply exponential to gate
gT = gT.Exp(ctx)
// state = state * g_t
state = state.Mul(ctx, gT)
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
kvMem := state.Mul(ctx, kTUnsqueezed)
// Sum over dim=-2 (second dimension after permute)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kvMem = kvMem.SumRows(ctx)
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
// v_t with singleton dimension
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
// delta = (v_t - kv_mem) * beta_t
vDiff := vT.Sub(ctx, kvMem)
delta := vDiff.Mul(ctx, betaT)
// state = state + k_t.unsqueeze(-1) * delta
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
state = state.Add(ctx, kTDelta)
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
stateQ := state.Mul(ctx, qTUnsqueezed)
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
coreAttnOut := stateQ.SumRows(ctx)
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
}
// deltaNetChunked implements chunked computation for prefill.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetChunked(
ctx ml.Context,
q, k, v, gate, beta, state ml.Tensor,
masks *Masks,
opts *Options,
layer int,
cache *HybridCache,
) ml.Tensor {
headKDim := q.Dim(0)
numVHeads := v.Dim(1)
headVDim := v.Dim(0)
nTokens := q.Dim(2)
nSeqs := q.Dim(3)
// L2 normalize Q and K
q = q.L2Norm(ctx, opts.eps)
k = k.L2Norm(ctx, opts.eps)
// Scale Q
scale := 1.0 / math.Sqrt(float64(headVDim))
q = q.Scale(ctx, scale)
// Sigmoid beta
beta = beta.Sigmoid(ctx)
// Permute tensors for chunked computation
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Compute padding
pad := (chunkSize - nTokens%chunkSize) % chunkSize
nChunks := (nTokens + pad) / chunkSize
// Pad tensors
if pad > 0 {
q = q.Pad(ctx, 0, pad, 0, 0)
k = k.Pad(ctx, 0, pad, 0, 0)
v = v.Pad(ctx, 0, pad, 0, 0)
gate = gate.Pad(ctx, pad, 0, 0, 0)
beta = beta.Pad(ctx, 0, pad, 0, 0)
}
// Use pre-computed masks (passed in, not recreated)
causalMask := masks.Causal
identity := masks.Identity
diagMask := masks.Diag
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
// v_beta = v * beta, k_beta = k * beta
vBeta := v.Mul(ctx, beta)
kBeta := k.Mul(ctx, beta)
// Reshape for chunked computation
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
// g_cumsum = cumsum(gate)
gCumsum := gate.CumSum(ctx)
// Compute decay mask
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
decayMask := gcsBroadcast.Sub(ctx, gcsI)
decayMask = decayMask.Mul(ctx, diagMask)
decayMask = decayMask.Exp(ctx)
decayMask = decayMask.Mul(ctx, diagMask)
// k @ k_beta^T
kMulKBeta := k.Mulmat(ctx, kBeta)
// k_decay = k @ k_beta^T * decay_mask
kDecay := kMulKBeta.Mul(ctx, decayMask)
// attn = -k_decay * causal_mask
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
// Triangular solve: (I - attn_lower)^-1 @ attn
attnLower := attn.Mul(ctx, causalMask)
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
attn = linSolve.Mul(ctx, causalMask)
attn = attn.Add(ctx, identity4D)
// v = v_beta^T @ attn
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
v = vBetaT.Mulmat(ctx, attn)
// Compute g_exp for state update
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
gExp := gCumsumT.Exp(ctx)
// kbeta_gexp = k_beta * g_exp
kBetaGExp := kBeta.Mul(ctx, gExp)
// k_cumdecay = attn @ kbeta_gexp^T
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
attnKQ := k.Mulmat(ctx, q)
attnKQ = attnKQ.Mul(ctx, decayMask)
attnKQ = attnKQ.Mul(ctx, diagMask)
// Pre-compute g_last and key_gdiff
// g_last = view of last element in g_cumsum along chunk_size dimension
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
gLastExp := gLast.Exp(ctx)
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
gDiffExp := gDiff.Exp(ctx)
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
newState := state
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
// state^T - permute is needed but Contiguous creates a copy
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// v_prime = k_cumdecay @ state
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
// v_new = v - v_prime
vNew := vChunk.Sub(ctx, vPrime)
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// attn_inter = (q * g_exp) @ state
qGExp := qChunk.Mul(ctx, gExpChunk)
attnInter := stateT.Mulmat(ctx, qGExp)
// core_attn_out = attn_inter + attn @ v_new
vAttn := vNewT.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
if coreAttnOut == nil {
coreAttnOut = coreAttnOutChunk
} else {
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// state = state * g_last + kgdmulvnew
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
newState = newState.Mul(ctx, gExpLastReshaped)
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
}
// Final reshape
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
// Slice to remove padding
if pad > 0 {
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
}

View File

@@ -0,0 +1,384 @@
package qwen3next
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int
numKVHeads int
keyLength int
valueLength int
ropeDim int
eps float32
ropeBase float32
ropeScale float32
ropeType string
originalContextLength int
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
normTopKProb bool
// Linear attention (Gated Delta Net) config
ssmDInner int // d_inner = head_v_dim * num_v_heads
ssmDState int // head_k_dim
ssmNGroup int // num_k_heads
ssmDtRank int // num_v_heads
convKernelSize int // SSM conv kernel size
// Per-layer type from GGUF metadata
isRecurrent []bool
// Pre-computed masks for chunked attention (created once per forward pass)
masks *Masks
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
ropeDim := cmp.Or(o.ropeDim, o.headDim())
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
}
// Operator is the interface for attention-like operators
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// sparse implements MoE with shared experts
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
// Shared experts
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
// Router logits
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
// Softmax routing weights
routingWeights := routerLogits.Softmax(ctx)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
// Expert computation with SILU activation
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
experts := gateOut.SILU(ctx, upOut)
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
// Add shared experts if present
if mlp.SharedUp != nil {
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
sharedOut := sharedGate.SILU(ctx, sharedUp)
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
// Apply shared expert gating
if mlp.SharedGateInp != nil {
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
// Broadcast gate to match dimensions
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
}
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}
// dense implements standard feedforward
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
Operator Operator
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-attention norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Attention (full or linear)
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
if err != nil {
return nil, err
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// First residual connection
hiddenStates = hiddenStates.Add(ctx, residual)
// Save for FFN residual
ffnResidual := hiddenStates
// Post-attention norm (before FFN)
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
// FFN
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
// Second residual connection
return hiddenStates.Add(ctx, ffnResidual), nil
}
// Model is the main Qwen3-Next model
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
// Create masks once per forward pass
m.Options.masks = createMasks(ctx)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer head counts (for detecting layer type)
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
layers[i].Operator = &GatedDeltaNet{Layer: i}
} else {
layers[i].Operator = &FullAttention{}
}
if isMoE {
layers[i].MLP = &sparse{}
} else {
layers[i].MLP = &dense{}
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: func() int {
for _, v := range headCountKV {
if v > 0 {
return int(v)
}
}
return 0
}(),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attentionScale: float64(c.Float("attention.scale")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
isRecurrent: isRecurrent,
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
}
// Calculate cache dimensions
convDim := max(0, opts.convKernelSize-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
headVDim := 0
numVHeads := opts.ssmDtRank
if numVHeads > 0 {
headVDim = opts.ssmDInner / numVHeads
}
deltaStateSize := headVDim * headVDim * numVHeads
// Validate dimension assumption: headKDim == headVDim is required for state computations
headKDim := opts.ssmDState
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
}
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
return &m, nil
}
func init() {
model.Register("qwen3next", New)
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
model.TextProcessor
tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

17
model/parsers/glmocr.go Normal file
View File

@@ -0,0 +1,17 @@
package parsers
import "github.com/ollama/ollama/api"
// GlmOcrParser is the GLM46 parser with thinking disabled.
type GlmOcrParser struct {
GLM46Parser
}
func (p *GlmOcrParser) HasThinkingSupport() bool {
return false
}
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}

View File

@@ -71,6 +71,8 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "glm-ocr":
return &GlmOcrParser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":

109
model/renderers/glmocr.go Normal file
View File

@@ -0,0 +1,109 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GlmOcrRenderer struct{}
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
enableThinking := false
thinkingExplicitlySet := false
if thinkValue != nil {
enableThinking = thinkValue.Bool()
thinkingExplicitlySet = true
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>\n")
if message.Thinking != "" {
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
} else {
sb.WriteString("<think></think>")
}
if message.Content != "" {
sb.WriteString("\n" + strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
sb.WriteString("\n")
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
sb.WriteString("\n")
}
}
sb.WriteString("<|assistant|>\n")
if thinkingExplicitlySet && !enableThinking {
sb.WriteString("<think></think>\n")
}
return sb.String(), nil
}
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}

View File

@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user\n")
sb.WriteString(imStartTag + "user")
}
sb.WriteString("<tool_response>\n")
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
sb.WriteString("\n</tool_response>")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {

View File

@@ -1,6 +1,7 @@
package renderers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -127,8 +128,7 @@ fahrenheit
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>assistant
`,
},
@@ -280,8 +279,7 @@ call tool<|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response>
<|im_end|>
</tool_response><|im_end|>
<|im_start|>assistant
`,
},
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
}
}
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
msgs := []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"payload": "ok"}),
}},
}},
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
}
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
}
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
}
}
func TestQwen3ToolDefinitionTypes(t *testing.T) {
tests := []struct {
name string

View File

@@ -82,6 +82,8 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "glm-ocr":
return &GlmOcrRenderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":

View File

@@ -1,249 +0,0 @@
package model
import (
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/ollama/ollama/logutil"
)
const spmWhitespaceSep = "▁"
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

View File

@@ -1,17 +0,0 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}

View File

@@ -1,53 +0,0 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
}
if c.cache != nil {
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
if numPast > 0 {
// Recurrent caches use checkpoints to pick a safe resume position.
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
numPast = restored
} else {
numPast = 0
}
} else if !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -514,13 +515,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -709,7 +703,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -740,8 +733,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
// If the sequence was replaced while this batch was computing, discard results.
if activeBatch.seqs[i] != seq {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -766,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -775,18 +774,25 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
seq.pendingResponses = append(seq.pendingResponses, piece)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {
@@ -873,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
@@ -1358,7 +1364,7 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
// Dummy load to get the backend wired up
f, err := os.CreateTemp("", "*.bin")
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
defer f.Close()
@@ -1368,13 +1374,13 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
"general.architecture": "llama",
"tokenizer.ggml.model": "gpt2",
}, nil); err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
slog.Debug("dummy model load took", "duration", time.Since(startLoad))

Some files were not shown because too many files have changed in this diff Show More