mirror of
https://github.com/ollama/ollama.git
synced 2026-02-06 21:53:11 -05:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
099a0f18ef | ||
|
|
fff696ee31 | ||
|
|
2e3ce6eab3 | ||
|
|
9e2003f88a | ||
|
|
42e1d49fbe | ||
|
|
814630ca60 | ||
|
|
87cf187774 | ||
|
|
6ddd8862cd | ||
|
|
f1373193dc | ||
|
|
8a4b77f9da | ||
|
|
5f53fe7884 | ||
|
|
7ab4ca0e7f | ||
|
|
e36f389e82 | ||
|
|
c61023f554 | ||
|
|
d25535c3f3 | ||
|
|
c323161f24 | ||
|
|
255579aaa7 | ||
|
|
f7102ba826 | ||
|
|
cefabd79a8 | ||
|
|
df70249520 | ||
|
|
77eb2ca619 | ||
|
|
ee25219edd | ||
|
|
b1fccabb34 | ||
|
|
a6355329bf | ||
|
|
0398b24b42 | ||
|
|
75b1dddf91 | ||
|
|
e1e80ffc3e | ||
|
|
71896485fd | ||
|
|
ef00199fb4 | ||
|
|
8f4a008139 | ||
|
|
d8cc798c2b | ||
|
|
6582f6da5c | ||
|
|
0334ffa625 | ||
|
|
d11fbd2c60 | ||
|
|
6a7c3f188e | ||
|
|
427e2c962a |
22
.github/workflows/test-install.yaml
vendored
Normal file
22
.github/workflows/test-install.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: test-install
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'scripts/install.sh'
|
||||
- '.github/workflows/test-install.yaml'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run install script
|
||||
run: sh ./scripts/install.sh
|
||||
env:
|
||||
OLLAMA_NO_START: 1 # do not start app
|
||||
- name: Verify ollama is available
|
||||
run: ollama --version
|
||||
@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -147,7 +147,7 @@ ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
|
||||
153
anthropic/anthropic.go
Normal file → Executable file
153
anthropic/anthropic.go
Normal file → Executable file
@@ -211,6 +211,7 @@ type MessageDelta struct {
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
@@ -517,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -550,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -721,6 +728,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
})
|
||||
}
|
||||
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
@@ -732,6 +740,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
@@ -776,3 +785,117 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
142
anthropic/anthropic_test.go
Normal file → Executable file
142
anthropic/anthropic_test.go
Normal file → Executable file
@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -642,7 +640,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
@@ -650,6 +648,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_delta" {
|
||||
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||
if data.Type != "message_delta" {
|
||||
t.Errorf("unexpected data type: %+v", data)
|
||||
}
|
||||
|
||||
if data.Delta.StopReason != "end_turn" {
|
||||
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||
}
|
||||
|
||||
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("unexpected data: %+v", e.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
@@ -660,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -713,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -760,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -824,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -881,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -919,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -951,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||
if tokens < 1 {
|
||||
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||
}
|
||||
// Sanity check: shouldn't be wildly off
|
||||
if tokens > 10 {
|
||||
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// System prompt adds to count
|
||||
if tokens < 5 {
|
||||
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather for a location",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Tools add significant content
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Thinking content should be counted
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
if tokens != 0 {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
19
cmd/cmd.go
19
cmd/cmd.go
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -52,7 +53,7 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
@@ -366,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
if _, err := client.Whoami(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -663,6 +675,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
_ = browser.OpenURL(aErr.SigninURL)
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
@@ -1750,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
105
cmd/cmd_test.go
105
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -1553,7 +1554,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
QuantizationLevel: "Q8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
@@ -1565,7 +1566,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" quantization Q8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
remoteHost: "https://other-remote.com",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
whoamiCalled := false
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
case "/api/me":
|
||||
whoamiCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.whoamiStatus)
|
||||
if tt.whoamiResp != nil {
|
||||
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
return nil
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
@@ -41,20 +48,146 @@ func (c *Claude) findPath() (string, error) {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := selectPrompt("Select model:", items)
|
||||
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -84,18 +85,114 @@ func TestClaudeArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,20 +14,21 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -20,6 +21,14 @@ type config struct {
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,6 +36,44 @@ func configPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
@@ -34,6 +81,11 @@ func load() (*config, error) {
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
@@ -79,8 +131,16 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
@@ -100,6 +160,29 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
677
cmd/config/config_cloud_test.go
Normal file
677
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
@@ -200,12 +247,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
@@ -267,7 +312,7 @@ func TestConfigPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
@@ -322,6 +367,183 @@ func TestLoad(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -39,7 +41,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -53,7 +55,7 @@ func (d *Droid) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -112,9 +114,17 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
@@ -122,7 +132,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
|
||||
@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
if err := d.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
|
||||
models := settings["customModels"].([]any)
|
||||
entry := models[0].(map[string]any)
|
||||
if entry["maxOutputTokens"] != float64(64000) {
|
||||
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
Run(model string, args []string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
@@ -38,6 +39,15 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
|
||||
// Integrations like Claude and Codex use this to route model requests to local models.
|
||||
type AliasConfigurer interface {
|
||||
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
// SetAliases syncs the configured aliases to the server
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
@@ -49,6 +59,15 @@ var integrations = map[string]Runner{
|
||||
"openclaw": &Openclaw{},
|
||||
}
|
||||
|
||||
// recommendedModels are shown when the user has no models or as suggestions.
|
||||
// Order matters: local models first, then cloud models.
|
||||
var recommendedModels = []selectItem{
|
||||
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
|
||||
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
|
||||
{Name: "glm-4.7:cloud", Description: "Recommended"},
|
||||
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
|
||||
}
|
||||
|
||||
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
@@ -94,152 +113,226 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := selectPrompt(prompt, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
if len(toPull) > 0 {
|
||||
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errCancelled
|
||||
}
|
||||
for _, m := range toPull {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, m); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// showOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func showOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{
|
||||
Name: m.Name,
|
||||
Remote: m.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
|
||||
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if err := ac.SetAliases(ctx, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
return saveAliases(name, aliases)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
@@ -248,7 +341,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
@@ -263,14 +356,37 @@ Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
// No "--" separator: only allow 0 or 1 args (integration name)
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
// "--" was used: args before it = integration name, args after = passthrough
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
@@ -286,16 +402,107 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
// Handle AliasConfigurer integrations (claude, codex)
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate --model flag if provided
|
||||
if modelFlag != "" {
|
||||
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var model string
|
||||
var existingAliases map[string]string
|
||||
|
||||
// Load saved config
|
||||
if cfg, err := loadIntegration(name); err == nil {
|
||||
existingAliases = cfg.Aliases
|
||||
if len(cfg.Models) > 0 {
|
||||
model = cfg.Models[0]
|
||||
// AliasConfigurer integrations use single model; sanitize if multiple
|
||||
if len(cfg.Models) > 1 {
|
||||
_ = saveIntegration(name, []string{model})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --model flag overrides saved model
|
||||
if modelFlag != "" {
|
||||
model = modelFlag
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := showOrPull(cmd.Context(), client, model); err != nil {
|
||||
model = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
}
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Sync aliases and save
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
if err := saveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
// Launch (unless --config without confirmation)
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
|
||||
// Validate --model flag for non-AliasConfigurer integrations
|
||||
if modelFlag != "" {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
@@ -304,6 +511,8 @@ Examples:
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
return runIntegration(name, saved.Models[0], passArgs)
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
@@ -350,13 +559,13 @@ Examples:
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -364,3 +573,163 @@ Examples:
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Non-existing models get "install?" suffix and are pushed to the bottom.
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
if items[i].Description != "" {
|
||||
items[i].Description += ", install?"
|
||||
} else {
|
||||
items[i].Description = "install?"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if !ac && !bc && aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
// isCloudModel checks if a model is a cloud model using the Show API.
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -90,8 +98,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -121,7 +129,7 @@ func TestLaunchCmd(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -174,15 +182,498 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
var _ func(string, []string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseArgs(t *testing.T) {
|
||||
// Tests reflect cobra's ArgsLenAtDash() semantics:
|
||||
// - cobra strips "--" from args
|
||||
// - ArgsLenAtDash() returns the index where "--" was, or -1
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string // args as cobra delivers them (no "--")
|
||||
dashIdx int // what ArgsLenAtDash() returns
|
||||
wantName string
|
||||
wantArgs []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no extra args, no dash",
|
||||
args: []string{"claude"},
|
||||
dashIdx: -1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "with extra args after --",
|
||||
args: []string{"codex", "-p", "myprofile"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"-p", "myprofile"},
|
||||
},
|
||||
{
|
||||
name: "extra args only after --",
|
||||
args: []string{"codex", "--sandbox", "workspace-write"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"--sandbox", "workspace-write"},
|
||||
},
|
||||
{
|
||||
name: "-- at end with no args after",
|
||||
args: []string{"claude"},
|
||||
dashIdx: 1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "-- with no integration name",
|
||||
args: []string{"--verbose"},
|
||||
dashIdx: 0,
|
||||
wantName: "",
|
||||
wantArgs: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple args before -- is error",
|
||||
args: []string{"claude", "codex", "--verbose"},
|
||||
dashIdx: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple args without -- is error",
|
||||
args: []string{"claude", "codex"},
|
||||
dashIdx: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no args, no dash",
|
||||
args: []string{},
|
||||
dashIdx: -1,
|
||||
wantName: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the parsing logic from LaunchCmd using dashIdx
|
||||
var name string
|
||||
var parsedArgs []string
|
||||
var err error
|
||||
|
||||
dashIdx := tt.dashIdx
|
||||
args := tt.args
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
err = fmt.Errorf("unexpected arguments: %v", args[1:])
|
||||
} else if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
} else {
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
parsedArgs = args[dashIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if name != tt.wantName {
|
||||
t.Errorf("name = %q, want %q", name, tt.wantName)
|
||||
}
|
||||
if !slices.Equal(parsedArgs, tt.wantArgs) {
|
||||
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, so nil client always returns false
|
||||
t.Run("nil client returns false", func(t *testing.T) {
|
||||
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
|
||||
for _, model := range models {
|
||||
if isCloudModel(context.Background(), nil, model) {
|
||||
t.Errorf("isCloudModel(%q) with nil client should return false", model)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
t.Errorf("pre-checked model should be first, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
case "glm-4.7-flash", "glm-4.7:cloud":
|
||||
if strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "kimi-k2.5:cloud", "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
|
||||
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "kimi-k2.5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_LatestTagStripped(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash:latest", Remote: false},
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// :latest should be stripped from display names
|
||||
for _, name := range got {
|
||||
if strings.HasSuffix(name, ":latest") {
|
||||
t.Errorf("name %q should not have :latest suffix", name)
|
||||
}
|
||||
}
|
||||
|
||||
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
|
||||
count := 0
|
||||
for _, name := range got {
|
||||
if name == "glm-4.7-flash" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
|
||||
}
|
||||
|
||||
// Stripped name should be in existingModels so it won't be pulled
|
||||
if !existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should be in existingModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if !existingModels["llama3.2"] {
|
||||
t.Error("llama3.2 should be in existingModels")
|
||||
}
|
||||
if !existingModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in existingModels")
|
||||
}
|
||||
if existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
|
||||
}
|
||||
|
||||
if !cloudModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in cloudModels")
|
||||
}
|
||||
if !cloudModels["kimi-k2.5:cloud"] {
|
||||
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
|
||||
}
|
||||
if cloudModels["llama3.2"] {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save a config for opencode so it looks like a previous launch
|
||||
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(saved.Models) == 0 {
|
||||
t.Fatal("expected saved models")
|
||||
}
|
||||
if saved.Models[0] != "llama3.2" {
|
||||
t.Errorf("expected llama3.2, got %s", saved.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"test-model"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPull(context.Background(), client, "test-model")
|
||||
if err != nil {
|
||||
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
|
||||
err := showOrPull(context.Background(), client, "missing-model")
|
||||
if err == nil {
|
||||
t.Error("showOrPull should return error when model not found and no terminal available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
|
||||
var receivedModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
|
||||
receivedModel = req.Model
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
_ = showOrPull(context.Background(), client, "qwen3:8b")
|
||||
if receivedModel != "qwen3:8b" {
|
||||
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_NoCloudModels(t *testing.T) {
|
||||
// ensureAuth should be a no-op when no cloud models are selected
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("no API calls expected when no cloud models selected")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
|
||||
// ensureAuth should only care about models in cloudModels map
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"testuser"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"cloud-model:cloud", "local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
|
||||
}
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// cloudModels has entries but none are in selected
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got: %v", err)
|
||||
}
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called when no cloud models are selected")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,7 @@ type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string) error {
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
@@ -38,7 +36,21 @@ func (c *Openclaw) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, "gateway")
|
||||
if !c.onboarded() {
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
@@ -54,6 +66,35 @@ func (c *Openclaw) Run(model string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
|
||||
@@ -763,3 +763,116 @@ func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawOnboarded(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns false when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no config exists")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no wizard section")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when wizard.lastRunAt is set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checks legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when legacy config has wizard.lastRunAt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
// New path has no wizard marker
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
// Legacy has wizard marker
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false - should prefer new path which has no wizard marker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false for corrupted JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles wrong type for wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard is wrong type")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -10,15 +11,56 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
@@ -32,7 +74,7 @@ func (o *OpenCode) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -113,6 +155,8 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -122,12 +166,29 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -495,6 +496,166 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -17,6 +17,7 @@ const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
@@ -353,10 +354,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
desc := ""
|
||||
if item.Description != "" {
|
||||
desc = " " + ansiGray + "- " + item.Description + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
@@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
@@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "no matches") {
|
||||
t.Errorf("expected 'no matches' message, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
t.Error("expected 'no matches' message for empty list with no filter")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,19 +10,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
return errNotRunning
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
|
||||
@@ -313,8 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
case "GlmOcrForConditionalGeneration":
|
||||
conv = &glmOcrModel{}
|
||||
case "Lfm2ForCausalLM":
|
||||
conv = &lfm2Model{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
conv = &qwen3NextModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
455
convert/convert_glmocr.go
Normal file
455
convert/convert_glmocr.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
|
||||
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
|
||||
//
|
||||
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
|
||||
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
|
||||
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
rotaryDim := int(float32(headDim) * partialRotaryFactor)
|
||||
if rotaryDim%2 != 0 {
|
||||
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
|
||||
}
|
||||
|
||||
// Handle 1D (bias) or 2D (weight) tensors
|
||||
is1D := len(shape) == 1
|
||||
var inFeatures int
|
||||
if is1D {
|
||||
inFeatures = 1
|
||||
} else {
|
||||
inFeatures = int(shape[1])
|
||||
}
|
||||
outFeatures := int(shape[0])
|
||||
nEffectiveHeads := outFeatures / headDim
|
||||
|
||||
if nEffectiveHeads != nHeads {
|
||||
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
|
||||
}
|
||||
|
||||
// Reshape to [n_heads, head_dim, in_features]
|
||||
reshaped := make([]float32, len(data))
|
||||
copy(reshaped, data)
|
||||
|
||||
// Permute the rotary dimensions: even indices first, then odd
|
||||
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
|
||||
result := make([]float32, len(data))
|
||||
halfRotary := rotaryDim / 2
|
||||
|
||||
for h := range nEffectiveHeads {
|
||||
for f := range inFeatures {
|
||||
for i := range halfRotary {
|
||||
// Even dim (0, 2, 4, ...) -> position i
|
||||
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
|
||||
dstIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
|
||||
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
|
||||
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
|
||||
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
}
|
||||
|
||||
// Non-rotary part: copy as-is
|
||||
for i := rotaryDim; i < headDim; i++ {
|
||||
srcIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[srcIdx] = reshaped[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
type glmOcrModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeParameters struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
MRopeSection []int32 `json:"mrope_section"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
|
||||
VisionConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
Depth uint32 `json:"depth"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
OutHiddenSize uint32 `json:"out_hidden_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
|
||||
ImageStartTokenID uint32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||
VideoStartTokenID uint32 `json:"video_start_token_id"`
|
||||
VideoEndTokenID uint32 `json:"video_end_token_id"`
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
VideoTokenID uint32 `json:"video_token_id"`
|
||||
|
||||
// Preprocessor config (preprocessor_config.json)
|
||||
Preprocessor struct {
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
MergeSize uint32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
} `json:"-"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*glmOcrModel)(nil)
|
||||
|
||||
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(bts, &m.Preprocessor)
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "glmocr"
|
||||
|
||||
// Text model parameters
|
||||
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
|
||||
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
|
||||
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
|
||||
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
|
||||
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
|
||||
kv["glmocr.attention.key_length"] = headDim
|
||||
kv["glmocr.attention.value_length"] = headDim
|
||||
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
|
||||
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
|
||||
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
|
||||
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
|
||||
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
|
||||
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
|
||||
}
|
||||
|
||||
// Vision model parameters
|
||||
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
|
||||
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
|
||||
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
|
||||
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
|
||||
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
|
||||
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
|
||||
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
|
||||
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
|
||||
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
|
||||
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
|
||||
|
||||
// Preprocessor-derived image settings (min/max pixels and normalization)
|
||||
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
|
||||
if m.Preprocessor.Size.ShortestEdge > 0 {
|
||||
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
|
||||
}
|
||||
if m.Preprocessor.Size.LongestEdge > 0 {
|
||||
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
|
||||
}
|
||||
if len(m.Preprocessor.ImageMean) == 3 {
|
||||
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
|
||||
}
|
||||
if len(m.Preprocessor.ImageStd) == 3 {
|
||||
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
kv["glmocr.image_token_id"] = m.ImageTokenID
|
||||
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
|
||||
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
|
||||
kv["glmocr.video_token_id"] = m.VideoTokenID
|
||||
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
|
||||
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
|
||||
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
|
||||
skipLayer := func(name string) bool {
|
||||
// Tensor names are already replaced to "blk.N.xxx" format
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(name)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return blkNum >= numLayers
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip next-n prediction layers (layers >= num_hidden_layers)
|
||||
if skipLayer(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split ffn_gate_up into separate gate and up projections
|
||||
if strings.Contains(name, "ffn_gate_up") {
|
||||
for t := range splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
|
||||
) {
|
||||
out = append(out, t)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, "patch_embd.weight") {
|
||||
shape := t.Shape()
|
||||
if len(shape) == 5 && shape[2] == 2 {
|
||||
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
|
||||
|
||||
t0 := t.Clone()
|
||||
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t0,
|
||||
})
|
||||
|
||||
t1 := t.Clone()
|
||||
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t1,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if len(shape) == 4 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
|
||||
// Fall through to default handling
|
||||
}
|
||||
|
||||
// Handle pre-split patch embedding weights
|
||||
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
|
||||
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
|
||||
if strings.Contains(name, "patch_embd.0.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.Contains(name, "patch_embd.1.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle .weight.0 and .weight.1 suffix patterns
|
||||
if strings.HasSuffix(name, "patch_embd.weight.0") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "patch_embd.weight.1") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
|
||||
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
|
||||
// We permute at conversion time so the weights work correctly with GGML's kernel
|
||||
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
|
||||
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
|
||||
// Get config values for permutation
|
||||
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
|
||||
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
|
||||
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
|
||||
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
|
||||
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
|
||||
|
||||
// Use appropriate head count: nHeads for Q, nKVHeads for K
|
||||
effectiveHeads := nHeads
|
||||
if strings.Contains(name, "attn_k.") {
|
||||
effectiveHeads = nKVHeads
|
||||
}
|
||||
|
||||
permutedT := t.Clone()
|
||||
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: permutedT,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Replacements() []string {
|
||||
return []string{
|
||||
// Vision encoder
|
||||
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
|
||||
"model.visual.patch_embed.proj", "v.patch_embd",
|
||||
"model.visual.blocks", "v.blk",
|
||||
"model.visual.post_layernorm", "v.post_ln",
|
||||
"model.visual.downsample", "mm.patch_merger",
|
||||
|
||||
// Vision attention
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"attn.q_norm", "attn_q_norm",
|
||||
"attn.k_norm", "attn_k_norm",
|
||||
|
||||
// Vision norms
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
|
||||
// Vision MLP
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
|
||||
// Merger (multimodal projector)
|
||||
"model.visual.merger.proj", "mm.model.fc",
|
||||
"model.visual.merger.post_projection_norm", "mm.post_norm",
|
||||
"model.visual.merger.gate_proj", "mm.gate",
|
||||
"model.visual.merger.up_proj", "mm.up",
|
||||
"model.visual.merger.down_proj", "mm.down",
|
||||
|
||||
// Language model
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"lm_head", "output",
|
||||
|
||||
// Language model attention
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
|
||||
// Language model norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"post_self_attn_layernorm", "post_attn_norm",
|
||||
"post_mlp_layernorm", "post_ffn_norm",
|
||||
|
||||
// Language model MLP (remove mlp. prefix so ffn_* names work)
|
||||
"mlp.gate_up_proj", "ffn_gate_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
}
|
||||
}
|
||||
512
convert/convert_qwen3next.go
Normal file
512
convert/convert_qwen3next.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
// MoE config
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
|
||||
// Hybrid attention config
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
|
||||
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
|
||||
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
|
||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||
|
||||
// RoPE config
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||
|
||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
if q.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||
}
|
||||
if q.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_attention_heads must be set")
|
||||
}
|
||||
if q.NumKeyValueHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
|
||||
}
|
||||
if q.HeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: head_dim must be set")
|
||||
}
|
||||
if q.RopeTheta == 0 {
|
||||
return fmt.Errorf("qwen3next: rope_theta must be set")
|
||||
}
|
||||
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
|
||||
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
|
||||
}
|
||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||
}
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
hasFull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen3next"
|
||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
||||
kv["block_count"] = q.NumHiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
headDim := q.HeadDim
|
||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||
}
|
||||
kv["attention.key_length"] = headDim
|
||||
kv["attention.value_length"] = headDim
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
|
||||
// RoPE dimension count (partial rotary)
|
||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
||||
partialRotary := q.PartialRotaryFactor
|
||||
if partialRotary > 0 && partialRotary <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||
}
|
||||
|
||||
// MoE config
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
if q.MoEIntermediateSize > 0 {
|
||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||
}
|
||||
if q.SharedExpertIntermSize > 0 {
|
||||
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
|
||||
}
|
||||
}
|
||||
|
||||
// SSM/Linear attention config
|
||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||
kv["ssm.inner_size"] = dInner
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||
interval := q.FullAttentionInterval
|
||||
kv["full_attention_interval"] = interval
|
||||
|
||||
// Build per-layer KV head count array to identify layer types
|
||||
// 0 = recurrent (linear attention), non-zero = full attention
|
||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
||||
for i := range q.NumHiddenLayers {
|
||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
||||
if interval > 0 && (i+1)%interval == 0 {
|
||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
||||
}
|
||||
// else stays 0 (recurrent layer)
|
||||
}
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
|
||||
// RoPE scaling
|
||||
if q.RopeScaling.Type != "" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
continue
|
||||
}
|
||||
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
headVDim int
|
||||
numKHeads int
|
||||
numVHeads int
|
||||
qkvzDim int
|
||||
qkvOut int
|
||||
gateOut int
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
|
||||
if len(shape) != 2 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
numKHeads := int(q.LinearNumKeyHeads)
|
||||
numVHeads := int(q.LinearNumValueHeads)
|
||||
headKDim := int(q.LinearKeyHeadDim)
|
||||
headVDim := int(q.LinearValueHeadDim)
|
||||
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
if numVHeads%numKHeads != 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
hidden := int(shape[1])
|
||||
vPerHead := headVDim * (numVHeads / numKHeads)
|
||||
qkvzDim := 2*headKDim + 2*vPerHead
|
||||
expectedOut := qkvzDim * numKHeads
|
||||
if int(shape[0]) != expectedOut {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
return qkvzSplitSpec{
|
||||
hidden: hidden,
|
||||
headKDim: headKDim,
|
||||
headVDim: headVDim,
|
||||
numKHeads: numKHeads,
|
||||
numVHeads: numVHeads,
|
||||
qkvzDim: qkvzDim,
|
||||
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
|
||||
gateOut: headVDim * numVHeads,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
|
||||
spec, ok := q.qkvzSpec(t.Shape())
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qkvTensor := t.Clone()
|
||||
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
|
||||
|
||||
gateTensor := t.Clone()
|
||||
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
|
||||
|
||||
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
|
||||
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
|
||||
|
||||
return &ggml.Tensor{
|
||||
Name: qkvName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
|
||||
WriterTo: qkvTensor,
|
||||
}, &ggml.Tensor{
|
||||
Name: gateName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
|
||||
WriterTo: gateTensor,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
|
||||
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := 0
|
||||
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += vPerHead
|
||||
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
|
||||
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
|
||||
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
|
||||
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
|
||||
|
||||
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out tensor.Tensor
|
||||
if extractGate {
|
||||
out = zMat
|
||||
} else {
|
||||
out, err = tensor.Concat(1, qMat, kMat, vMat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
out, err = tensor.Transpose(out, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(out.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
|
||||
n, err := n.Add(ones)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
|
||||
// Full attention (self_attn)
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
"linear_attn.A_log", "ssm_a",
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
|
||||
@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
|
||||
if st.dtype == "BF16" &&
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
!strings.HasPrefix(st.name, "mm.") &&
|
||||
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@
|
||||
{
|
||||
"source": "/api",
|
||||
"destination": "/api/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/integrations/clawdbot",
|
||||
"destination": "/integrations/openclaw"
|
||||
}
|
||||
],
|
||||
"navigation": {
|
||||
|
||||
@@ -312,7 +312,7 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
|
||||
@@ -201,7 +201,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 4096,
|
||||
"": 0,
|
||||
"2048": 2048,
|
||||
}
|
||||
|
||||
|
||||
@@ -268,8 +268,10 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"lfm2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
@@ -859,11 +861,13 @@ func (f GGML) FlashAttention() bool {
|
||||
"bert",
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -27,6 +27,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
|
||||
3
go.sum
3
go.sum
@@ -174,6 +174,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -304,6 +306,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
|
||||
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||
func TestNumPredict(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
||||
t.Fatalf("failed to pull model: %v", err)
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "Write a long story.",
|
||||
Stream: &stream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"num_predict": 10,
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
logprobCount := 0
|
||||
var finalResponse api.GenerateResponse
|
||||
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
|
||||
logprobCount += len(resp.Logprobs)
|
||||
if resp.Done {
|
||||
finalResponse = resp
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("generate failed: %v", err)
|
||||
}
|
||||
|
||||
if logprobCount != 10 {
|
||||
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
|
||||
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,3 +75,10 @@ type Cache interface {
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
|
||||
// CheckpointCache optionally supports restoring recurrent state to a prior
|
||||
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
|
||||
// The returned position is the number of tokens that can be kept (prefix length).
|
||||
type CheckpointCache interface {
|
||||
PrepareRestore(seq int, targetPos int32) (int32, bool)
|
||||
}
|
||||
|
||||
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
@@ -0,0 +1,276 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jeffrey Morgan <jmorganca@gmail.com>
|
||||
Date: Tue, 3 Feb 2026 12:00:00 -0800
|
||||
Subject: [PATCH] ggml: metal solve_tri
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
|
||||
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
|
||||
7 files changed, 177 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 680904d13..83385c9ef 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
+ assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
+
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
+
|
||||
+ char base[256];
|
||||
+ char name[256];
|
||||
+
|
||||
+ snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
+ snprintf(name, 256, "%s", base);
|
||||
+
|
||||
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
+ if (!res.pipeline) {
|
||||
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
+ }
|
||||
+
|
||||
+ return res;
|
||||
+}
|
||||
+
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index 0a8b9211a..8a9d17460 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index 7b5ee968c..4e5acfbe5 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ return ggml_is_contiguous(op->src[0]) &&
|
||||
+ ggml_is_contiguous(op->src[1]) &&
|
||||
+ op->src[0]->type == GGML_TYPE_F32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_F32 &&
|
||||
+ op->type == GGML_TYPE_F32;
|
||||
+ case GGML_OP_COUNT_EQUAL:
|
||||
+ return has_simdgroup_reduction &&
|
||||
+ op->src[0]->type == GGML_TYPE_I32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_I32 &&
|
||||
+ op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
index 8944b07e9..cfdea9c07 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
+typedef struct {
|
||||
+ int32_t ne00;
|
||||
+ int32_t ne01;
|
||||
+ int32_t ne02;
|
||||
+ int32_t ne03;
|
||||
+ uint64_t nb00;
|
||||
+ uint64_t nb01;
|
||||
+ uint64_t nb02;
|
||||
+ uint64_t nb03;
|
||||
+ int32_t ne10;
|
||||
+ int32_t ne11;
|
||||
+ uint64_t nb10;
|
||||
+ uint64_t nb11;
|
||||
+ uint64_t nb12;
|
||||
+ uint64_t nb13;
|
||||
+ uint64_t nb0;
|
||||
+ uint64_t nb1;
|
||||
+ uint64_t nb2;
|
||||
+ uint64_t nb3;
|
||||
+} ggml_metal_kargs_solve_tri;
|
||||
+
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 80864f303..4ac135603 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ {
|
||||
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
+ } break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
+ ggml_tensor * op = ctx->node(idx);
|
||||
+
|
||||
+ ggml_metal_library_t lib = ctx->lib;
|
||||
+ ggml_metal_encoder_t enc = ctx->enc;
|
||||
+
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
+
|
||||
+ ggml_metal_kargs_solve_tri args = {
|
||||
+ /*.ne00 =*/ ne00,
|
||||
+ /*.ne01 =*/ ne01,
|
||||
+ /*.ne02 =*/ ne02,
|
||||
+ /*.ne03 =*/ ne03,
|
||||
+ /*.nb00 =*/ nb00,
|
||||
+ /*.nb01 =*/ nb01,
|
||||
+ /*.nb02 =*/ nb02,
|
||||
+ /*.nb03 =*/ nb03,
|
||||
+ /*.ne10 =*/ ne10,
|
||||
+ /*.ne11 =*/ ne11,
|
||||
+ /*.nb10 =*/ nb10,
|
||||
+ /*.nb11 =*/ nb11,
|
||||
+ /*.nb12 =*/ nb12,
|
||||
+ /*.nb13 =*/ nb13,
|
||||
+ /*.nb0 =*/ nb0,
|
||||
+ /*.nb1 =*/ nb1,
|
||||
+ /*.nb2 =*/ nb2,
|
||||
+ /*.nb3 =*/ nb3,
|
||||
+ };
|
||||
+
|
||||
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
+
|
||||
+ const int64_t ncols = ne10;
|
||||
+ const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
+ const int64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ int nth = 64;
|
||||
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
+ if (nth < 1) {
|
||||
+ nth = 1;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
+
|
||||
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
+
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
+
|
||||
+ return 1;
|
||||
+}
|
||||
+
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
index 902b54452..a475183d3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d33c16079..c37447a10 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
+kernel void kernel_solve_tri_f32(
|
||||
+ constant ggml_metal_kargs_solve_tri & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device char * dst,
|
||||
+ uint tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort ntg[[threads_per_threadgroup]]) {
|
||||
+ const uint64_t ncols = (uint64_t) args.ne10;
|
||||
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
+ const uint64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
+ if (gid >= nr) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
+ const uint64_t i02 = rem / ncols;
|
||||
+ const uint64_t i01 = rem - i02 * ncols;
|
||||
+
|
||||
+ const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
+ const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
+ const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
+ const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
+ const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
+ const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
+ const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
+ const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
+ const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
+ const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
+
|
||||
+ device const float * A = (device const float *) src0;
|
||||
+ device const float * B = (device const float *) src1;
|
||||
+ device float * X = (device float *) dst;
|
||||
+
|
||||
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
+
|
||||
+ const uint64_t n = (uint64_t) args.ne11;
|
||||
+
|
||||
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
+ float sum = 0.0f;
|
||||
+ for (uint64_t t = 0; t < i00; ++t) {
|
||||
+ sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
+ X[X_base + t * sx1 + i01 * sx0];
|
||||
+ }
|
||||
+
|
||||
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
+ X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -80,6 +81,7 @@ type LlamaServer interface {
|
||||
GetPort() int
|
||||
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
|
||||
HasExited() bool
|
||||
ContextLength() int
|
||||
}
|
||||
|
||||
// llmServer is an instance of a runner hosting a single model
|
||||
@@ -115,7 +117,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -141,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tok tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -154,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -210,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -260,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tok != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -309,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1200,7 +1202,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do load request: %w", err)
|
||||
slog.Error("do load request", "error", err)
|
||||
return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -1772,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1807,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1901,6 +1904,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *llmServer) ContextLength() int {
|
||||
return s.options.NumCtx
|
||||
}
|
||||
|
||||
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, s)
|
||||
if err != nil {
|
||||
|
||||
@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
@@ -170,10 +170,12 @@ type Tensor interface {
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
GELU_ERF(ctx Context) Tensor
|
||||
QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
SigmoidOut(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
@@ -206,6 +208,32 @@ type Tensor interface {
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Exp(ctx Context) Tensor
|
||||
Neg(ctx Context) Tensor
|
||||
|
||||
// Clamp clamps values to [min, max] range
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
|
||||
// Softplus computes ln(1 + exp(x))
|
||||
Softplus(ctx Context) Tensor
|
||||
|
||||
// CumSum computes cumulative sum along dimension 0
|
||||
CumSum(ctx Context) Tensor
|
||||
|
||||
// Diag creates a diagonal matrix from a 1D tensor
|
||||
Diag(ctx Context) Tensor
|
||||
|
||||
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
|
||||
Tri(ctx Context, triType int) Tensor
|
||||
|
||||
// Fill fills a tensor with a constant value (in-place)
|
||||
Fill(ctx Context, value float32) Tensor
|
||||
|
||||
// Repeat4D repeats tensor to match target shape
|
||||
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
|
||||
|
||||
// SolveTri solves a triangular system Ax = B
|
||||
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
|
||||
|
||||
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
|
||||
|
||||
sched := C.ggml_backend_sched_new_ext(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
|
||||
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
@@ -1581,6 +1588,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
var tt *C.struct_ggml_tensor
|
||||
if len(t2) > 0 {
|
||||
@@ -1772,6 +1786,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
|
||||
var mode C.uint32_t
|
||||
switch samplingMode {
|
||||
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
op->src[1]->type == GGML_TYPE_I32 &&
|
||||
op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
|
||||
@@ -2385,6 +2385,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_solve_tri args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
|
||||
const int64_t ncols = ne10;
|
||||
const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
const int64_t nr = n_batches * ncols;
|
||||
|
||||
int nth = 64;
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
if (nth < 1) {
|
||||
nth = 1;
|
||||
}
|
||||
|
||||
const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -56,6 +56,18 @@ type fakeTensor struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// Stub methods to satisfy ml.Tensor interface
|
||||
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
return &fakeTensor{Name: name}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
174
model/models/glmocr/imageprocessor.go
Normal file
174
model/models/glmocr/imageprocessor.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
spatialMergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
imageMean [3]float32
|
||||
imageStd [3]float32
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
|
||||
// Read normalization values from config if available, otherwise use CLIP defaults
|
||||
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
|
||||
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
|
||||
|
||||
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
|
||||
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
|
||||
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
|
||||
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 336)),
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
|
||||
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
|
||||
factor: patchSize * spatialMergeSize,
|
||||
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
|
||||
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
temporalFactor := p.temporalPatchSize
|
||||
numFrames := temporalFactor // single image
|
||||
|
||||
if height < factor || width < factor {
|
||||
// Scale up small images
|
||||
scale := float64(factor) / float64(min(height, width))
|
||||
height = int(math.Ceil(float64(height) * scale))
|
||||
width = int(math.Ceil(float64(width) * scale))
|
||||
}
|
||||
|
||||
if temporalFactor <= 0 {
|
||||
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
|
||||
temporalFactor = 1
|
||||
}
|
||||
if numFrames < temporalFactor {
|
||||
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
|
||||
numFrames = temporalFactor
|
||||
}
|
||||
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
|
||||
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
|
||||
}
|
||||
|
||||
round := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
|
||||
|
||||
if tBar*hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if tBar*hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
|
||||
|
||||
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
|
||||
// We keep [C, H, W] for patch extraction
|
||||
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
|
||||
|
||||
// Calculate grid dimensions (after Conv2D patching)
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // Single image
|
||||
ImageHeight: resizedHeight,
|
||||
ImageWidth: resizedWidth,
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return patches, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := 3
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.spatialMergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
patchIndex := 0
|
||||
|
||||
// Single temporal frame handling (copies to all frames)
|
||||
for range grid.Temporal {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
srcIdx := c*height*width + y*width + x
|
||||
dstIdx := channelOffset + (py * patchSize) + px
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
|
||||
if temporalPatchSize > 1 {
|
||||
frameSize := patchSize * patchSize
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||
result[channelOffset:channelOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
236
model/models/glmocr/model.go
Normal file
236
model/models/glmocr/model.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
|
||||
PatchMerger *PatchMerger `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageStartTokenID int32
|
||||
imageEndTokenID int32
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
|
||||
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: allEOS,
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: int32(c.Uint("image_token_id", 59280)),
|
||||
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
|
||||
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Blocks) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create pixel values tensor from flattened patches
|
||||
// Shape: [patchDim, numPatches]
|
||||
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||
|
||||
// Forward through vision encoder
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||
|
||||
// Forward through downsample (patch merger)
|
||||
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
|
||||
return nil, errors.New("glmocr: missing vision downsample weights")
|
||||
}
|
||||
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
|
||||
|
||||
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
|
||||
if m.PatchMerger == nil {
|
||||
return nil, errors.New("glmocr: missing patch merger weights")
|
||||
}
|
||||
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
// Reset position cache
|
||||
m.TextModel.positionCache = m.TextModel.positionCache[:0]
|
||||
m.TextModel.ropeDelta = 0
|
||||
|
||||
pos := int32(0)
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Get grid info for position calculation
|
||||
grid := inp.Multimodal[0].Data.(*Grid)
|
||||
mergedH := grid.Height / m.VisionModel.spatialMergeSize
|
||||
mergedW := grid.Width / m.VisionModel.spatialMergeSize
|
||||
|
||||
// Add image start token
|
||||
result = append(result, &input.Input{Token: m.imageStartTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
|
||||
// Add image tokens with multimodal data
|
||||
// All image tokens share the same base position for temporal dimension
|
||||
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||
basePos := pos
|
||||
sameBatch := tokensPerGrid - 1
|
||||
if sameBatch < 0 {
|
||||
sameBatch = 0
|
||||
}
|
||||
result = append(result, &input.Input{
|
||||
Token: m.imageTokenID,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: sameBatch,
|
||||
})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
|
||||
// Add placeholder tokens for remaining positions
|
||||
// All image tokens use the same base position (temporal stays constant)
|
||||
for range tokensPerGrid - 1 {
|
||||
result = append(result, &input.Input{Token: m.imageTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
}
|
||||
|
||||
// Advance position by max(mergedH, mergedW) after image tokens
|
||||
pos = basePos + int32(max(mergedH, mergedW))
|
||||
|
||||
// Add image end token
|
||||
result = append(result, &input.Input{Token: m.imageEndTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
}
|
||||
|
||||
// Compute rope delta for continuation after the prefill segment:
|
||||
// delta = (max_position_id + 1) - sequence_length
|
||||
if len(m.TextModel.positionCache) > 0 {
|
||||
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
|
||||
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||
ctx.Forward(hiddenStates)
|
||||
|
||||
// Build position slices for M-RoPE
|
||||
positionSlice := func() [][]int32 {
|
||||
s := [][]int32{
|
||||
make([]int32, len(batch.Positions)), // temporal
|
||||
make([]int32, len(batch.Positions)), // height
|
||||
make([]int32, len(batch.Positions)), // width
|
||||
make([]int32, len(batch.Positions)), // unused (zeros)
|
||||
}
|
||||
for i, position := range batch.Positions {
|
||||
// Translate through position cache or continue sequence
|
||||
if position < int32(len(m.TextModel.positionCache)) {
|
||||
position = m.TextModel.positionCache[position]
|
||||
} else if len(m.TextModel.positionCache) > 0 {
|
||||
// Continue sequence after cached positions using ropeDelta
|
||||
position = position + m.TextModel.ropeDelta
|
||||
}
|
||||
|
||||
s[0][i] = position
|
||||
s[1][i] = position
|
||||
s[2][i] = position
|
||||
}
|
||||
return s
|
||||
}()
|
||||
|
||||
// Inject vision embeddings and adjust positions for image tokens
|
||||
for _, mi := range batch.Multimodal {
|
||||
img := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||
|
||||
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||
w := grid.Width / m.VisionModel.spatialMergeSize
|
||||
for i := range img.Dim(1) {
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.TextModel.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("glmocr", New)
|
||||
}
|
||||
190
model/models/glmocr/model_text.go
Normal file
190
model/models/glmocr/model_text.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type TextModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
headDim int
|
||||
rotaryDim int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
ropeBase float32
|
||||
mropeSections []int
|
||||
}
|
||||
|
||||
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
// With 4 sections for [temporal, height, width, unused]
|
||||
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Separate Q, K, V projections
|
||||
q := sa.Query.Forward(ctx, hiddenStates)
|
||||
k := sa.Key.Forward(ctx, hiddenStates)
|
||||
v := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
// Reshape for GQA
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
// Apply M-RoPE (multi-resolution rotary position embeddings)
|
||||
q = opts.applyMRoPE(ctx, q, positions)
|
||||
k = opts.applyMRoPE(ctx, k, positions)
|
||||
|
||||
// Scaled dot-product attention with KV cache
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
|
||||
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
|
||||
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type TextDecoderLayer struct {
|
||||
// Input layernorm (before attention)
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
// Post self-attention layernorm (after attention, before residual add)
|
||||
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
|
||||
|
||||
// FFN input layernorm (after first residual, before MLP)
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
// Post MLP layernorm (after MLP, before residual add)
|
||||
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
|
||||
}
|
||||
|
||||
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
// Attention block
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Prune to output positions in final layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// MLP block
|
||||
residual = hiddenStates
|
||||
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []TextDecoderLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextModelOptions
|
||||
|
||||
// positionCache stores the M-RoPE position for each token in the sequence.
|
||||
// This is needed because image tokens share the same base position but have
|
||||
// different height/width offsets, and the end token position depends on the
|
||||
// image grid dimensions.
|
||||
positionCache []int32
|
||||
ropeDelta int32
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// Clear position cache when KV cache shifts
|
||||
m.positionCache = nil
|
||||
m.ropeDelta = 0
|
||||
return m.applyMRoPE(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
hiddenSize := int(c.Uint("embedding_length", 1536))
|
||||
numHeads := int(c.Uint("attention.head_count", 16))
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
|
||||
intermediateSize := int(c.Uint("feed_forward_length", 4608))
|
||||
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
|
||||
ropeBase := c.Float("rope.freq_base", 10000)
|
||||
|
||||
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
|
||||
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
|
||||
if ropeDim <= 0 {
|
||||
ropeDim = headDim
|
||||
}
|
||||
|
||||
mropeSections := c.Ints("rope.mrope_section")
|
||||
var sectionInts []int
|
||||
|
||||
if len(mropeSections) > 0 {
|
||||
sectionInts = make([]int, len(mropeSections))
|
||||
for i, section := range mropeSections {
|
||||
sectionInts[i] = int(section)
|
||||
}
|
||||
} else {
|
||||
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
|
||||
// For rotaryDim=64 this yields [8, 12, 12].
|
||||
total := ropeDim / 2
|
||||
if total <= 0 {
|
||||
total = 32
|
||||
}
|
||||
s0 := total * 2 / 8
|
||||
s1 := total * 3 / 8
|
||||
s2 := total - s0 - s1
|
||||
sectionInts = []int{s0, s1, s2}
|
||||
}
|
||||
|
||||
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
|
||||
rotaryDim := ropeDim
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
|
||||
TextModelOptions: &TextModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
headDim: headDim,
|
||||
rotaryDim: rotaryDim,
|
||||
intermediateSize: intermediateSize,
|
||||
eps: eps,
|
||||
ropeBase: ropeBase,
|
||||
mropeSections: sectionInts,
|
||||
},
|
||||
}
|
||||
}
|
||||
355
model/models/glmocr/model_vision.go
Normal file
355
model/models/glmocr/model_vision.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type Grid struct {
|
||||
Height int // Number of patches in height direction
|
||||
Width int // Number of patches in width direction
|
||||
Temporal int
|
||||
ImageHeight int // Full image height in pixels
|
||||
ImageWidth int // Full image width in pixels
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
imageSize int
|
||||
spatialMergeSize int
|
||||
outHiddenSize int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionPatchEmbed struct {
|
||||
Proj *nn.Conv2D `gguf:"patch_embd_0"`
|
||||
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
Bias ml.Tensor `gguf:"patch_embd.bias"`
|
||||
}
|
||||
|
||||
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
|
||||
_ = grid // patches are already in merge-block order
|
||||
|
||||
// pixelValues shape: [patchDim, numPatches]
|
||||
numPatches := pixelValues.Shape()[1]
|
||||
|
||||
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
|
||||
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Slice temporal frames for Conv2D (simulate Conv3D)
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
|
||||
s0, s1 := opts.patchSize, opts.patchSize
|
||||
p0, p1 := 0, 0
|
||||
d0, d1 := 1, 1
|
||||
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||
hiddenStates = hiddenStates.Add(ctx, out1)
|
||||
}
|
||||
|
||||
// Flatten to [hidden_size, num_patches]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
|
||||
|
||||
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
|
||||
if pe.Bias != nil {
|
||||
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
QKV *nn.Linear `gguf:"attn_qkv"`
|
||||
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Combined QKV projection: [3*hidden_size, batch_size]
|
||||
qkv := sa.QKV.Forward(ctx, hiddenStates)
|
||||
|
||||
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
|
||||
// ChunkSections returns views - must make contiguous before further operations
|
||||
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
|
||||
q := chunks[0].Contiguous(ctx)
|
||||
k := chunks[1].Contiguous(ctx)
|
||||
v := chunks[2].Contiguous(ctx)
|
||||
|
||||
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
|
||||
// Apply Q-norm and K-norm after head reshape
|
||||
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
|
||||
q = sa.QNorm.Forward(ctx, q, opts.eps)
|
||||
k = sa.KNorm.Forward(ctx, k, opts.eps)
|
||||
|
||||
// Apply rotary position embeddings with vision-style 2D positions.
|
||||
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
|
||||
// We provide H/W sections and leave the remaining sections empty.
|
||||
ropeFreqBase := float32(10000.0)
|
||||
section := opts.headDim / 4
|
||||
if section <= 0 {
|
||||
section = 1
|
||||
}
|
||||
sections := []int{section, section, 0, 0}
|
||||
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
// Try flash attention first (ScaledDotProductAttention), fall back to manual
|
||||
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
|
||||
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: vision attention falling back to manual attention",
|
||||
"batchSize", batchSize, "numHeads", opts.numHeads,
|
||||
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
|
||||
|
||||
// Manual attention fallback
|
||||
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
// Attention scores
|
||||
kq := k.MulmatFullPrec(ctx, q)
|
||||
kq = kq.Scale(ctx, scale)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
// Attention output: v @ kq (note: v first)
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type VisionBlock struct {
|
||||
Norm1 *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
Norm2 *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Pre-norm architecture
|
||||
residual := hiddenStates
|
||||
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionDownsample struct {
|
||||
*nn.Conv2D
|
||||
}
|
||||
|
||||
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
|
||||
// Apply spatial downsampling via Conv2D
|
||||
// Input: [hidden_size, num_patches] where patches are in merge-block order
|
||||
|
||||
if d.Conv2D == nil || d.Weight == nil {
|
||||
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
|
||||
return hiddenStates // Return input unchanged as fallback
|
||||
}
|
||||
|
||||
merge := opts.spatialMergeSize
|
||||
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
|
||||
|
||||
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
|
||||
|
||||
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
|
||||
// ggml semantics: result.ne[perm[i]] = input.ne[i]
|
||||
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
|
||||
// Step 3: Apply Conv2D without bias (bias added after reshape)
|
||||
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
|
||||
s0, s1 := merge, merge
|
||||
p0, p1 := 0, 0
|
||||
d0, d1 := 1, 1
|
||||
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
|
||||
|
||||
// Step 5: Add bias after reshape
|
||||
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
|
||||
if d.Bias != nil {
|
||||
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type PatchMerger struct {
|
||||
// GGUF tags align with mm.* keys used by the model
|
||||
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
|
||||
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
|
||||
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
|
||||
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
|
||||
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
|
||||
}
|
||||
|
||||
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Linear projection
|
||||
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
|
||||
|
||||
// Post-projection layer norm + GELU ERF
|
||||
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = hiddenStates.GELU_ERF(ctx)
|
||||
// Force a copy to avoid in-place mutation issues with GELU_ERF
|
||||
hiddenStates = hiddenStates.Contiguous(ctx)
|
||||
|
||||
// SwiGLU MLP: down(silu(gate(x)) * up(x))
|
||||
gateOut := m.GateProj.Forward(ctx, hiddenStates)
|
||||
upOut := m.UpProj.Forward(ctx, hiddenStates)
|
||||
gate := gateOut.SILU(ctx, upOut)
|
||||
return m.DownProj.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbed *VisionPatchEmbed
|
||||
Blocks []VisionBlock `gguf:"blk"`
|
||||
PostLN *nn.RMSNorm `gguf:"post_ln"`
|
||||
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||
// Extract patch embeddings from flattened patches
|
||||
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
|
||||
|
||||
// Create position IDs for RoPE (spatial grid)
|
||||
// Patches are already in merge-block order from preprocessing
|
||||
positions := m.createPositions(ctx, grid)
|
||||
|
||||
// Process through vision blocks
|
||||
for _, block := range m.Blocks {
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
// Post-layernorm
|
||||
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
// Note: Downsample is now applied separately in Model.EncodeMultimodal
|
||||
// so mm.patch_merger remains a distinct module
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
// Create spatial position IDs for vision RoPE
|
||||
// Position layout: [height, width, height, width] - 4 sections for mrope
|
||||
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
|
||||
// This follows the GLM-OCR rot_pos_emb layout
|
||||
numPatches := grid.Height * grid.Width
|
||||
mergeRatio := m.spatialMergeSize
|
||||
|
||||
// Build position arrays in merge-block order
|
||||
// Each merge_ratio x merge_ratio block of patches is grouped together
|
||||
hpos := make([]int32, numPatches)
|
||||
wpos := make([]int32, numPatches)
|
||||
ptr := 0
|
||||
for y := 0; y < grid.Height; y += mergeRatio {
|
||||
for x := 0; x < grid.Width; x += mergeRatio {
|
||||
for dy := range mergeRatio {
|
||||
for dx := range mergeRatio {
|
||||
hpos[ptr] = int32(y + dy)
|
||||
wpos[ptr] = int32(x + dx)
|
||||
ptr++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
|
||||
// keep remaining sections zeroed to match its conventions.
|
||||
zeros := make([]int32, numPatches)
|
||||
s := [][]int32{
|
||||
hpos, // Section 0: height
|
||||
wpos, // Section 1: width
|
||||
zeros, // Section 2: unused
|
||||
zeros, // Section 3: unused
|
||||
}
|
||||
|
||||
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
|
||||
numHeads := int(c.Uint("vision.attention.head_count", 16))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
imageSize := int(c.Uint("vision.image_size", 336))
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
|
||||
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
|
||||
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
|
||||
|
||||
return &VisionModel{
|
||||
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: hiddenSize / numHeads,
|
||||
numChannels: numChannels,
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
imageSize: imageSize,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
outHiddenSize: outHiddenSize,
|
||||
intermediateSize: intermediateSize,
|
||||
eps: eps,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
_ "github.com/ollama/ollama/model/models/lfm2"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
@@ -19,5 +20,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3next"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,28 +45,24 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
103
model/models/qwen3next/attention.go
Normal file
103
model/models/qwen3next/attention.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
|
||||
// with the attention layer requirements.
|
||||
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
|
||||
|
||||
// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
|
||||
// Key differences from standard attention:
|
||||
// - Q projection outputs 2x size (Q + gate interleaved)
|
||||
// - Both Q and K have RMSNorm
|
||||
// - Output is gated: attn * sigmoid(gate)
|
||||
type FullAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
// Use Dim() instead of Shape() for consistent behavior during graph construction
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
|
||||
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
seqTokens := cache.seqTokens()
|
||||
seqs := cache.numSeqs()
|
||||
if seqTokens > 0 && seqs > 0 {
|
||||
if nSeqs > 0 {
|
||||
// 3D tensor: [hiddenDim, seqTokens, nSeqs]
|
||||
if batchSize != seqTokens || nSeqs != seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
|
||||
batchSize = seqTokens * seqs
|
||||
} else if batchSize != seqTokens*seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
}
|
||||
headDim := opts.headDim()
|
||||
numHeads := opts.numHeads
|
||||
|
||||
// Q projection outputs query + gate interleaved
|
||||
qFull := sa.Query.Forward(ctx, hiddenStates)
|
||||
|
||||
// Reshape to [headDim * 2, numHeads, batchSize]
|
||||
qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
|
||||
|
||||
// Split Q and gate along dimension 0
|
||||
// Q: first headDim elements, gate: second headDim elements
|
||||
query := qFull.Slice(ctx, 0, 0, headDim, 1)
|
||||
gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
|
||||
|
||||
// Make query contiguous for further operations
|
||||
query = query.Contiguous(ctx, headDim, numHeads, batchSize)
|
||||
|
||||
// K and V projections
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
// Derive numKVHeads from tensor dimensions (per-layer value)
|
||||
numKVHeads := key.Dim(0) / headDim
|
||||
|
||||
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// Apply QK normalization
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
// Apply RoPE
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
// Standard attention computation
|
||||
scale := opts.attentionScale
|
||||
if scale == 0 {
|
||||
scale = 1.0 / math.Sqrt(float64(headDim))
|
||||
}
|
||||
attention := nn.Attention(ctx, query, key, value, scale, cache)
|
||||
|
||||
// Flatten heads
|
||||
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
|
||||
|
||||
// Apply sigmoid gate
|
||||
// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
|
||||
gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
|
||||
gateSigmoid := gate.Sigmoid(ctx)
|
||||
attention = attention.Mul(ctx, gateSigmoid)
|
||||
|
||||
return sa.Output.Forward(ctx, attention), nil
|
||||
}
|
||||
596
model/models/qwen3next/cache.go
Normal file
596
model/models/qwen3next/cache.go
Normal file
@@ -0,0 +1,596 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for full attention layers
|
||||
// - per-sequence conv state for linear attention layers
|
||||
// - per-sequence delta state for linear attention layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int // convKernelSize - 1
|
||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
||||
|
||||
// Delta state dimensions
|
||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer delta state buffers (allocated lazily)
|
||||
deltaCtxs map[int]ml.Context
|
||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointDeltaCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewHybridCache(
|
||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||
convDim, convChannels, deltaStateSize int,
|
||||
) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
convDim: convDim,
|
||||
convChannels: convChannels,
|
||||
deltaStateSize: deltaStateSize,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
deltaCtxs: make(map[int]ml.Context),
|
||||
deltaStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: checkpointCountDefault,
|
||||
checkpointMinPos: checkpointMinPosDefault,
|
||||
checkpointInterval: checkpointIntervalDefault,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.deltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointDeltaCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
c.reserveCheckpoints = true
|
||||
c.planCheckpoints(batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero state for newly allocated slots
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = false
|
||||
c.planCheckpoints(batch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Zero conv states
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// Zero delta states
|
||||
if len(c.deltaStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
||||
for _, buf := range c.deltaStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.deltaStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// Copy-on-write for recurrent state
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
// If the recurrent slot is shared, detach it before applying a restore.
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
// Removal invalidates recurrent state
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.deltaStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
// Recurrent delta state must stay in F32.
|
||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
||||
c.deltaStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||
c.ensureWritableOnce(ctx)
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
// Fast path: contiguous slots allow a single view + copy
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
||||
ctx.Forward(srcF32.Copy(ctx, view))
|
||||
} else {
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
498
model/models/qwen3next/checkpoints.go
Normal file
498
model/models/qwen3next/checkpoints.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
checkpointCountDefault = 32
|
||||
checkpointMinPosDefault = int32(16)
|
||||
checkpointIntervalDefault = int32(1280)
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
delta map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.delta {
|
||||
buf := c.deltaBuffer(ctx, layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.deltaStates {
|
||||
if entry.delta == nil || entry.delta[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.delta != nil {
|
||||
if dstEntry.delta == nil {
|
||||
dstEntry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.delta {
|
||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointDelta(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointDelta(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.delta == nil {
|
||||
entry.delta = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.delta[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointDeltaCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
||||
entry.delta[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointDelta(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
300
model/models/qwen3next/checkpoints_test.go
Normal file
300
model/models/qwen3next/checkpoints_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestBackend(tb testing.TB) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
_ = f.Close()
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 1, 1)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
||||
backend := newTestBackend(t)
|
||||
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.slotForSeq[2] = 0
|
||||
cache.refCount[0] = 2
|
||||
cache.refCount[1] = 0
|
||||
cache.freeSlots = []int{1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[1] != 1 {
|
||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
||||
}
|
||||
if cache.slotForSeq[2] != 0 {
|
||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
||||
}
|
||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
||||
}
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
// Only set conv checkpoint, not delta - making it incomplete
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.delta is not set, so checkpoint is incomplete
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 1, 2, 2)
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Don't set convStates/deltaStates - with no layers to check,
|
||||
// entryComplete will return true as long as entry.pos >= 0
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
// Test that restoreComplete returns true when no layers need checkpoints
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
// Fill the buffer
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Create fake tensor data in the first entry's maps
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
||||
|
||||
// Record another entry, which should wrap around and overwrite entry 0
|
||||
store.record(40)
|
||||
|
||||
// Verify the maps are still present (we reuse tensors)
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].delta == nil {
|
||||
t.Fatalf("expected delta map to be preserved on reuse")
|
||||
}
|
||||
|
||||
// Verify the new position was recorded
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
// Test behavior when buffer is exactly at capacity
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
// Verify both checkpoints are accessible
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
// Test behavior with zero-size buffer
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
// Test pruning that removes all checkpoints
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
// Prune everything by setting threshold below all positions
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
// When all checkpoints are pruned, lastPos is reset to -1
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
472
model/models/qwen3next/deltanet.go
Normal file
472
model/models/qwen3next/deltanet.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
const chunkSize = 64
|
||||
|
||||
// TriType constants for triangular matrix operations
|
||||
const (
|
||||
TriTypeUpperDiag = 0
|
||||
TriTypeUpper = 1
|
||||
TriTypeLowerDiag = 2
|
||||
TriTypeLower = 3
|
||||
)
|
||||
|
||||
// convKernel wraps the 1D convolution kernel tensor
|
||||
type convKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// Masks holds pre-computed mask tensors for chunked attention
|
||||
type Masks struct {
|
||||
Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
|
||||
Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
|
||||
Diag ml.Tensor // causal + identity
|
||||
}
|
||||
|
||||
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
|
||||
// It implements the Operator interface directly.
|
||||
type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
// Layer index for cache access (set during model construction)
|
||||
Layer int
|
||||
}
|
||||
|
||||
// createMasks builds the constant mask tensors (called once, reused for all chunks)
|
||||
func createMasks(ctx ml.Context) *Masks {
|
||||
ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
|
||||
ones = ones.Fill(ctx, 1.0)
|
||||
causalMask := ones.Tri(ctx, TriTypeLower)
|
||||
|
||||
onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
|
||||
onesVec = onesVec.Fill(ctx, 1.0)
|
||||
identity := onesVec.Diag(ctx)
|
||||
|
||||
diagMask := causalMask.Add(ctx, identity)
|
||||
|
||||
return &Masks{
|
||||
Causal: causalMask,
|
||||
Identity: identity,
|
||||
Diag: diagMask,
|
||||
}
|
||||
}
|
||||
|
||||
func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
layer := gdn.Layer
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
nSeqs := hiddenStates.Dim(2)
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
seqTokens := cache.seqTokens()
|
||||
seqs := cache.numSeqs()
|
||||
if seqTokens > 0 && seqs > 0 {
|
||||
if nSeqs > 1 {
|
||||
if nSeqTokens != seqTokens || nSeqs != seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
} else {
|
||||
if nSeqTokens != seqTokens*seqs {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
|
||||
nSeqTokens = seqTokens
|
||||
nSeqs = seqs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headKDim := opts.ssmDState
|
||||
numKHeads := opts.ssmNGroup
|
||||
numVHeads := opts.ssmDtRank
|
||||
headVDim := opts.ssmDInner / numVHeads
|
||||
convKernelSize := opts.convKernelSize
|
||||
|
||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
|
||||
}
|
||||
// Optimized path: pre-split QKV and gate
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
baNewDim := 2 * numVHeads / numKHeads
|
||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Split beta and alpha
|
||||
betaSize := numVHeads / numKHeads
|
||||
alphaSize := numVHeads / numKHeads
|
||||
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Reshape to merge head dimensions
|
||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Get conv state from cache
|
||||
convStates, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
// Log this - if it happens, short-term context will be lost
|
||||
slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
|
||||
convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
|
||||
}
|
||||
|
||||
// Reshape conv states
|
||||
convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
|
||||
|
||||
// Concatenate with input for convolution
|
||||
convInput := convStates.Concat(ctx, qkvMixed, 0)
|
||||
|
||||
// Save new conv state (last convKernelSize-1 tokens)
|
||||
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
|
||||
cache.UpdateConvState(ctx, layer, lastConvStates)
|
||||
|
||||
// Apply SSM convolution (kernel must be F32 for Metal)
|
||||
convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
|
||||
convOutput = convOutput.SILU(ctx)
|
||||
|
||||
// Reshape for extraction
|
||||
convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
|
||||
|
||||
// Extract convolved Q, K, V
|
||||
qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
|
||||
kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
|
||||
vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
|
||||
|
||||
// Reshape to 4D
|
||||
qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
|
||||
vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
// Get delta state from cache
|
||||
state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
|
||||
if err != nil {
|
||||
// Log this - if it happens frequently, context will degrade
|
||||
slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
|
||||
state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
|
||||
}
|
||||
state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
|
||||
|
||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||
if numKHeads != numVHeads {
|
||||
repeatFactor := numVHeads / numKHeads
|
||||
|
||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||
|
||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||
|
||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||
}
|
||||
|
||||
// Choose computation mode based on sequence length
|
||||
var attnOut ml.Tensor
|
||||
if nSeqTokens == 1 {
|
||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||
} else {
|
||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||
}
|
||||
|
||||
// Apply gated normalization
|
||||
attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
|
||||
|
||||
// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
|
||||
attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
|
||||
zSilu := z2D.SILU(ctx)
|
||||
attnOutGated := attnOutNorm.Mul(ctx, zSilu)
|
||||
|
||||
// Reshape for output projection
|
||||
finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
out := gdn.SSMOut.Forward(ctx, finalOutput)
|
||||
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
|
||||
}
|
||||
|
||||
// deltaNetAutoregressive implements single-token state update.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Reshape gate and beta for broadcasting
|
||||
gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
|
||||
// Apply exponential to gate
|
||||
gT = gT.Exp(ctx)
|
||||
|
||||
// state = state * g_t
|
||||
state = state.Mul(ctx, gT)
|
||||
|
||||
// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
|
||||
kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
kvMem := state.Mul(ctx, kTUnsqueezed)
|
||||
// Sum over dim=-2 (second dimension after permute)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kvMem = kvMem.SumRows(ctx)
|
||||
kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// v_t with singleton dimension
|
||||
vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
|
||||
|
||||
// delta = (v_t - kv_mem) * beta_t
|
||||
vDiff := vT.Sub(ctx, kvMem)
|
||||
delta := vDiff.Mul(ctx, betaT)
|
||||
|
||||
// state = state + k_t.unsqueeze(-1) * delta
|
||||
kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
|
||||
state = state.Add(ctx, kTDelta)
|
||||
|
||||
// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
|
||||
qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
|
||||
stateQ := state.Mul(ctx, qTUnsqueezed)
|
||||
stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
coreAttnOut := stateQ.SumRows(ctx)
|
||||
coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
|
||||
}
|
||||
|
||||
// deltaNetChunked implements chunked computation for prefill.
|
||||
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
|
||||
func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
ctx ml.Context,
|
||||
q, k, v, gate, beta, state ml.Tensor,
|
||||
masks *Masks,
|
||||
opts *Options,
|
||||
layer int,
|
||||
cache *HybridCache,
|
||||
) ml.Tensor {
|
||||
headKDim := q.Dim(0)
|
||||
numVHeads := v.Dim(1)
|
||||
headVDim := v.Dim(0)
|
||||
nTokens := q.Dim(2)
|
||||
nSeqs := q.Dim(3)
|
||||
|
||||
// L2 normalize Q and K
|
||||
q = q.L2Norm(ctx, opts.eps)
|
||||
k = k.L2Norm(ctx, opts.eps)
|
||||
|
||||
// Scale Q
|
||||
scale := 1.0 / math.Sqrt(float64(headVDim))
|
||||
q = q.Scale(ctx, scale)
|
||||
|
||||
// Sigmoid beta
|
||||
beta = beta.Sigmoid(ctx)
|
||||
|
||||
// Permute tensors for chunked computation
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
|
||||
|
||||
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Compute padding
|
||||
pad := (chunkSize - nTokens%chunkSize) % chunkSize
|
||||
nChunks := (nTokens + pad) / chunkSize
|
||||
|
||||
// Pad tensors
|
||||
if pad > 0 {
|
||||
q = q.Pad(ctx, 0, pad, 0, 0)
|
||||
k = k.Pad(ctx, 0, pad, 0, 0)
|
||||
v = v.Pad(ctx, 0, pad, 0, 0)
|
||||
gate = gate.Pad(ctx, pad, 0, 0, 0)
|
||||
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
||||
}
|
||||
|
||||
// Use pre-computed masks (passed in, not recreated)
|
||||
causalMask := masks.Causal
|
||||
identity := masks.Identity
|
||||
diagMask := masks.Diag
|
||||
identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
|
||||
|
||||
// v_beta = v * beta, k_beta = k * beta
|
||||
vBeta := v.Mul(ctx, beta)
|
||||
kBeta := k.Mul(ctx, beta)
|
||||
|
||||
// Reshape for chunked computation
|
||||
q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
|
||||
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
|
||||
// g_cumsum = cumsum(gate)
|
||||
gCumsum := gate.CumSum(ctx)
|
||||
|
||||
// Compute decay mask
|
||||
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||
gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
decayMask := gcsBroadcast.Sub(ctx, gcsI)
|
||||
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
decayMask = decayMask.Exp(ctx)
|
||||
decayMask = decayMask.Mul(ctx, diagMask)
|
||||
|
||||
// k @ k_beta^T
|
||||
kMulKBeta := k.Mulmat(ctx, kBeta)
|
||||
|
||||
// k_decay = k @ k_beta^T * decay_mask
|
||||
kDecay := kMulKBeta.Mul(ctx, decayMask)
|
||||
|
||||
// attn = -k_decay * causal_mask
|
||||
attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
|
||||
|
||||
// Triangular solve: (I - attn_lower)^-1 @ attn
|
||||
attnLower := attn.Mul(ctx, causalMask)
|
||||
lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
|
||||
linSolve := lhs.SolveTri(ctx, attn, true, true, false)
|
||||
attn = linSolve.Mul(ctx, causalMask)
|
||||
attn = attn.Add(ctx, identity4D)
|
||||
|
||||
// v = v_beta^T @ attn
|
||||
vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
v = vBetaT.Mulmat(ctx, attn)
|
||||
|
||||
// Compute g_exp for state update
|
||||
gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
gExp := gCumsumT.Exp(ctx)
|
||||
|
||||
// kbeta_gexp = k_beta * g_exp
|
||||
kBetaGExp := kBeta.Mul(ctx, gExp)
|
||||
|
||||
// k_cumdecay = attn @ kbeta_gexp^T
|
||||
kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
|
||||
kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
|
||||
attnKQ := k.Mulmat(ctx, q)
|
||||
attnKQ = attnKQ.Mul(ctx, decayMask)
|
||||
attnKQ = attnKQ.Mul(ctx, diagMask)
|
||||
|
||||
// Pre-compute g_last and key_gdiff
|
||||
// g_last = view of last element in g_cumsum along chunk_size dimension
|
||||
// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
|
||||
gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
|
||||
gLastExp := gLast.Exp(ctx)
|
||||
|
||||
// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
|
||||
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
||||
gDiffExp := gDiff.Exp(ctx)
|
||||
|
||||
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
|
||||
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Process chunks and update state
|
||||
var coreAttnOut ml.Tensor
|
||||
newState := state
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
||||
|
||||
// state^T - permute is needed but Contiguous creates a copy
|
||||
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// v_prime = k_cumdecay @ state
|
||||
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
|
||||
|
||||
// v_new = v - v_prime
|
||||
vNew := vChunk.Sub(ctx, vPrime)
|
||||
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// attn_inter = (q * g_exp) @ state
|
||||
qGExp := qChunk.Mul(ctx, gExpChunk)
|
||||
attnInter := stateT.Mulmat(ctx, qGExp)
|
||||
|
||||
// core_attn_out = attn_inter + attn @ v_new
|
||||
vAttn := vNewT.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
if coreAttnOut == nil {
|
||||
coreAttnOut = coreAttnOutChunk
|
||||
} else {
|
||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||
}
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||
|
||||
// state = state * g_last + kgdmulvnew
|
||||
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
newState = newState.Mul(ctx, gExpLastReshaped)
|
||||
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
|
||||
}
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
// Slice to remove padding
|
||||
if pad > 0 {
|
||||
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
||||
}
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
|
||||
}
|
||||
384
model/models/qwen3next/model.go
Normal file
384
model/models/qwen3next/model.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
keyLength int
|
||||
valueLength int
|
||||
ropeDim int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeScale float32
|
||||
ropeType string
|
||||
originalContextLength int
|
||||
attentionScale float64
|
||||
|
||||
// MoE config
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
normTopKProb bool
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
ssmDInner int // d_inner = head_v_dim * num_v_heads
|
||||
ssmDState int // head_k_dim
|
||||
ssmNGroup int // num_k_heads
|
||||
ssmDtRank int // num_v_heads
|
||||
convKernelSize int // SSM conv kernel size
|
||||
|
||||
// Per-layer type from GGUF metadata
|
||||
isRecurrent []bool
|
||||
|
||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||
masks *Masks
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
)
|
||||
}
|
||||
ropeDim := cmp.Or(o.ropeDim, o.headDim())
|
||||
return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
|
||||
}
|
||||
|
||||
// Operator is the interface for attention-like operators
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
|
||||
}
|
||||
|
||||
// MLP is the interface for feedforward networks
|
||||
type MLP interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
// sparse implements MoE with shared experts
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
|
||||
// Shared experts
|
||||
SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
|
||||
SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
|
||||
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||
if batchSize == 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||
|
||||
// Router logits
|
||||
routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
|
||||
|
||||
// Softmax routing weights
|
||||
routingWeights := routerLogits.Softmax(ctx)
|
||||
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
}
|
||||
|
||||
hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
|
||||
|
||||
// Expert computation with SILU activation
|
||||
gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
experts := gateOut.SILU(ctx, upOut)
|
||||
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum over experts
|
||||
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
// Add shared experts if present
|
||||
if mlp.SharedUp != nil {
|
||||
sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
|
||||
sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
|
||||
sharedOut := sharedGate.SILU(ctx, sharedUp)
|
||||
sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
|
||||
|
||||
// Apply shared expert gating
|
||||
if mlp.SharedGateInp != nil {
|
||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
|
||||
// Broadcast gate to match dimensions
|
||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||
}
|
||||
|
||||
moeOut = moeOut.Add(ctx, sharedOut)
|
||||
}
|
||||
|
||||
return moeOut
|
||||
}
|
||||
|
||||
// dense implements standard feedforward
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
|
||||
Operator Operator
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
residual := hiddenStates
|
||||
|
||||
// Pre-attention norm
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Attention (full or linear)
|
||||
var err error
|
||||
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Output projection for last layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
// First residual connection
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// Save for FFN residual
|
||||
ffnResidual := hiddenStates
|
||||
|
||||
// Post-attention norm (before FFN)
|
||||
hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// FFN
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
|
||||
// Second residual connection
|
||||
return hiddenStates.Add(ctx, ffnResidual), nil
|
||||
}
|
||||
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
// Create masks once per forward pass
|
||||
m.Options.masks = createMasks(ctx)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
// Get per-layer head counts (for detecting layer type)
|
||||
type headCounts interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
isMoE := c.Uint("expert_count") > 0
|
||||
|
||||
for i := range layers {
|
||||
if isRecurrent[i] {
|
||||
layers[i].Operator = &GatedDeltaNet{Layer: i}
|
||||
} else {
|
||||
layers[i].Operator = &FullAttention{}
|
||||
}
|
||||
|
||||
if isMoE {
|
||||
layers[i].MLP = &sparse{}
|
||||
} else {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: func() int {
|
||||
for _, v := range headCountKV {
|
||||
if v > 0 {
|
||||
return int(v)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
attentionScale: float64(c.Float("attention.scale")),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
ssmDInner: int(c.Uint("ssm.inner_size")),
|
||||
ssmDState: int(c.Uint("ssm.state_size")),
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
isRecurrent: isRecurrent,
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
convDim := max(0, opts.convKernelSize-1)
|
||||
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
|
||||
headVDim := 0
|
||||
numVHeads := opts.ssmDtRank
|
||||
if numVHeads > 0 {
|
||||
headVDim = opts.ssmDInner / numVHeads
|
||||
}
|
||||
deltaStateSize := headVDim * headVDim * numVHeads
|
||||
|
||||
// Validate dimension assumption: headKDim == headVDim is required for state computations
|
||||
headKDim := opts.ssmDState
|
||||
if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
|
||||
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
// Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
|
||||
// Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen3next", New)
|
||||
}
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
17
model/parsers/glmocr.go
Normal file
17
model/parsers/glmocr.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package parsers
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// GlmOcrParser is the GLM46 parser with thinking disabled.
|
||||
type GlmOcrParser struct {
|
||||
GLM46Parser
|
||||
}
|
||||
|
||||
func (p *GlmOcrParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
return tools
|
||||
}
|
||||
@@ -71,6 +71,8 @@ func ParserForName(name string) Parser {
|
||||
return &FunctionGemmaParser{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Parser{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrParser{}
|
||||
case "lfm2":
|
||||
return &LFM2Parser{hasThinkingSupport: false}
|
||||
case "lfm2-thinking":
|
||||
|
||||
109
model/renderers/glmocr.go
Normal file
109
model/renderers/glmocr.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GlmOcrRenderer struct{}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(formatGLM47ToolJSON(d))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
|
||||
}
|
||||
|
||||
enableThinking := false
|
||||
thinkingExplicitlySet := false
|
||||
if thinkValue != nil {
|
||||
enableThinking = thinkValue.Bool()
|
||||
thinkingExplicitlySet = true
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>\n")
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("<think>" + strings.TrimSpace(message.Thinking) + "</think>")
|
||||
} else {
|
||||
sb.WriteString("<think></think>")
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n" + strings.TrimSpace(message.Content))
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>" + toolCall.Function.Name)
|
||||
sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|assistant|>\n")
|
||||
if thinkingExplicitlySet && !enableThinking {
|
||||
sb.WriteString("<think></think>\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
|
||||
// only start a new user block if this is the first tool response
|
||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user\n")
|
||||
sb.WriteString(imStartTag + "user")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
sb.WriteString("\n</tool_response>")
|
||||
|
||||
// close the user block only if this is the last tool response
|
||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -127,8 +128,7 @@ fahrenheit
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>user
|
||||
That sounds nice! What about New York?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"number": 6}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -280,8 +279,7 @@ call tool<|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "call tool"},
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: testArgs(map[string]any{"payload": "ok"}),
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
|
||||
}
|
||||
|
||||
rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.Contains(rendered, "</tool_response>\n<|im_end|>") {
|
||||
t.Fatalf("expected no newline after </tool_response>, got:\n%s", rendered)
|
||||
}
|
||||
if !strings.Contains(rendered, "</tool_response><|im_end|>") {
|
||||
t.Fatalf("expected </tool_response> to be immediately followed by <|im_end|>, got:\n%s", rendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ToolDefinitionTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -82,6 +82,8 @@ func rendererForName(name string) Renderer {
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false}
|
||||
case "lfm2-thinking":
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
},
|
||||
true, // lowercase
|
||||
)
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
if numPast > 0 {
|
||||
// Recurrent caches use checkpoints to pick a safe resume position.
|
||||
if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
|
||||
if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
|
||||
numPast = restored
|
||||
} else {
|
||||
numPast = 0
|
||||
}
|
||||
} else if !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
}
|
||||
}
|
||||
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
text, _ := tok.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -514,13 +515,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
@@ -709,7 +703,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
@@ -740,8 +733,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
if seq == nil || nextBatchTokens[i] == nil {
|
||||
continue
|
||||
}
|
||||
// If the sequence was replaced while this batch was computing, discard results.
|
||||
if activeBatch.seqs[i] != seq {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
seq.numPredicted++
|
||||
if seq.numPredicted == 1 {
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
@@ -766,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -775,18 +774,25 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
@@ -873,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -3,7 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -11,22 +11,13 @@ func Execute(args []string) error {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
case "--ollama-engine":
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--imagegen-engine":
|
||||
return imagegen.Execute(args[1:])
|
||||
}
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
|
||||
pieces := make([]string, len(tok.Vocabulary().Values))
|
||||
for i := range tok.Vocabulary().Values {
|
||||
pieces[i], _ = tok.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user