mirror of
https://github.com/ollama/ollama.git
synced 2026-02-12 00:23:16 -05:00
Compare commits
64 Commits
v0.15.2
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ecb89007e | ||
|
|
da0190f8e4 | ||
|
|
1f2594f50f | ||
|
|
3fb0958e01 | ||
|
|
f08427c138 | ||
|
|
2dbb000908 | ||
|
|
c980e19995 | ||
|
|
6162374ca9 | ||
|
|
44bdd9a2ef | ||
|
|
db493d6e5e | ||
|
|
75695f16a5 | ||
|
|
a0407d07fa | ||
|
|
9ec733e527 | ||
|
|
5ef04dab52 | ||
|
|
aea316f1e9 | ||
|
|
235ba3df5c | ||
|
|
099a0f18ef | ||
|
|
fff696ee31 | ||
|
|
2e3ce6eab3 | ||
|
|
9e2003f88a | ||
|
|
42e1d49fbe | ||
|
|
814630ca60 | ||
|
|
87cf187774 | ||
|
|
6ddd8862cd | ||
|
|
f1373193dc | ||
|
|
8a4b77f9da | ||
|
|
5f53fe7884 | ||
|
|
7ab4ca0e7f | ||
|
|
e36f389e82 | ||
|
|
c61023f554 | ||
|
|
d25535c3f3 | ||
|
|
c323161f24 | ||
|
|
255579aaa7 | ||
|
|
f7102ba826 | ||
|
|
cefabd79a8 | ||
|
|
df70249520 | ||
|
|
77eb2ca619 | ||
|
|
ee25219edd | ||
|
|
b1fccabb34 | ||
|
|
a6355329bf | ||
|
|
0398b24b42 | ||
|
|
75b1dddf91 | ||
|
|
e1e80ffc3e | ||
|
|
71896485fd | ||
|
|
ef00199fb4 | ||
|
|
8f4a008139 | ||
|
|
d8cc798c2b | ||
|
|
6582f6da5c | ||
|
|
0334ffa625 | ||
|
|
d11fbd2c60 | ||
|
|
6a7c3f188e | ||
|
|
427e2c962a | ||
|
|
27db7f806f | ||
|
|
3590fbfa76 | ||
|
|
cd0094f772 | ||
|
|
06bc8e6712 | ||
|
|
fc5f9bb448 | ||
|
|
a0740f7ef7 | ||
|
|
a0923cbdd0 | ||
|
|
f92e362b2e | ||
|
|
aa23d8ecd2 | ||
|
|
7b62c41060 | ||
|
|
26acab64b7 | ||
|
|
e0f03790b1 |
6
.github/workflows/release.yaml
vendored
6
.github/workflows/release.yaml
vendored
@@ -337,6 +337,7 @@ jobs:
|
||||
name: bundles-windows
|
||||
path: |
|
||||
dist/*.zip
|
||||
dist/*.ps1
|
||||
dist/OllamaSetup.exe
|
||||
|
||||
linux-build:
|
||||
@@ -514,6 +515,9 @@ jobs:
|
||||
- name: Log dist contents
|
||||
run: |
|
||||
ls -l dist/
|
||||
- name: Copy install scripts to dist
|
||||
run: |
|
||||
cp scripts/install.sh dist/install.sh
|
||||
- name: Generate checksum file
|
||||
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
@@ -536,7 +540,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
22
.github/workflows/test-install.yaml
vendored
Normal file
22
.github/workflows/test-install.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: test-install
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'scripts/install.sh'
|
||||
- '.github/workflows/test-install.yaml'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run install script
|
||||
run: sh ./scripts/install.sh
|
||||
env:
|
||||
OLLAMA_NO_START: 1 # do not start app
|
||||
- name: Verify ollama is available
|
||||
run: ollama --version
|
||||
@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -147,7 +147,7 @@ ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
|
||||
@@ -358,6 +358,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
|
||||
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
|
||||
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
||||
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
|
||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
|
||||
@@ -465,6 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
||||
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
|
||||
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
|
||||
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
||||
153
anthropic/anthropic.go
Normal file → Executable file
153
anthropic/anthropic.go
Normal file → Executable file
@@ -211,6 +211,7 @@ type MessageDelta struct {
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
@@ -517,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -550,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -721,6 +728,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
})
|
||||
}
|
||||
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
@@ -732,6 +740,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
@@ -776,3 +785,117 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
142
anthropic/anthropic_test.go
Normal file → Executable file
142
anthropic/anthropic_test.go
Normal file → Executable file
@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -642,7 +640,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
@@ -650,6 +648,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_delta" {
|
||||
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||
if data.Type != "message_delta" {
|
||||
t.Errorf("unexpected data type: %+v", data)
|
||||
}
|
||||
|
||||
if data.Delta.StopReason != "end_turn" {
|
||||
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||
}
|
||||
|
||||
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("unexpected data: %+v", e.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
@@ -660,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -713,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -760,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -824,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -881,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -919,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -951,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||
if tokens < 1 {
|
||||
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||
}
|
||||
// Sanity check: shouldn't be wildly off
|
||||
if tokens > 10 {
|
||||
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// System prompt adds to count
|
||||
if tokens < 5 {
|
||||
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather for a location",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Tools add significant content
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Thinking content should be counted
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
if tokens != 0 {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
13
cmd/background_unix.go
Normal file
13
cmd/background_unix.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
|
||||
// Setpgid prevents the server from being killed when the parent process exits.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
12
cmd/background_windows.go
Normal file
12
cmd/background_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
|
||||
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: 0x08000000,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
283
cmd/cmd.go
283
cmd/cmd.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -29,6 +30,7 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -36,6 +38,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -52,7 +55,50 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return userName, err
|
||||
}
|
||||
|
||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
ok, err := tui.RunConfirm(prompt)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return false, config.ErrCancelled
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
}
|
||||
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
@@ -366,14 +412,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
if _, err := client.Whoami(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -663,6 +720,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
_ = browser.OpenURL(aErr.SigninURL)
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
@@ -1750,7 +1808,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1791,6 +1849,197 @@ Environment Variables:
|
||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
|
||||
}
|
||||
|
||||
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
|
||||
func ensureServerRunning(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if server is already running
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server is already running
|
||||
}
|
||||
|
||||
// Server not running, start it in the background
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find executable: %w", err)
|
||||
}
|
||||
|
||||
serverCmd := exec.CommandContext(ctx, exe, "serve")
|
||||
serverCmd.Env = os.Environ()
|
||||
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
|
||||
if err := serverCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the server to be ready
|
||||
for {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
if err := ensureServerRunning(cmd.Context()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
if result.Model != "" {
|
||||
// Model already selected from modal - save and launch
|
||||
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCLI() *cobra.Command {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
cobra.EnableCommandSorting = false
|
||||
@@ -1813,11 +2062,13 @@ func NewCLI() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Print(cmd.UsageString())
|
||||
runInteractiveTUI(cmd)
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
@@ -1888,7 +2139,7 @@ func NewCLI() *cobra.Command {
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Aliases: []string{"start"},
|
||||
Short: "Start ollama",
|
||||
Short: "Start Ollama",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: RunServer,
|
||||
}
|
||||
@@ -1921,6 +2172,15 @@ func NewCLI() *cobra.Command {
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
loginCmd := &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "Sign in to ollama.com",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
signoutCmd := &cobra.Command{
|
||||
Use: "signout",
|
||||
Short: "Sign out from ollama.com",
|
||||
@@ -1929,6 +2189,15 @@ func NewCLI() *cobra.Command {
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
logoutCmd := &cobra.Command{
|
||||
Use: "logout",
|
||||
Short: "Sign out from ollama.com",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
@@ -2025,13 +2294,15 @@ func NewCLI() *cobra.Command {
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
loginCmd,
|
||||
signoutCmd,
|
||||
logoutCmd,
|
||||
listCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
105
cmd/cmd_test.go
105
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -1553,7 +1554,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
QuantizationLevel: "Q8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
@@ -1565,7 +1566,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" quantization Q8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
remoteHost: "https://other-remote.com",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
whoamiCalled := false
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
case "/api/me":
|
||||
whoamiCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.whoamiStatus)
|
||||
if tt.whoamiResp != nil {
|
||||
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
return nil
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
@@ -39,20 +48,145 @@ func (c *Claude) findPath() (string, error) {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -84,18 +85,114 @@ func TestClaudeArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,20 +14,21 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,23 +3,37 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,6 +41,44 @@ func configPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
@@ -34,6 +86,11 @@ func load() (*config, error) {
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
@@ -79,13 +136,89 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return ic.Models[0]
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
func LastModel() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastModel
|
||||
}
|
||||
|
||||
// SetLastModel saves the last model that was run.
|
||||
func SetLastModel(model string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastModel = model
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
|
||||
func LastSelection() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastSelection
|
||||
}
|
||||
|
||||
// SetLastSelection saves the last menu selection ("run" or integration name).
|
||||
func SetLastSelection(selection string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastSelection = selection
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
@@ -100,6 +233,29 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
677
cmd/config/config_cloud_test.go
Normal file
677
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
@@ -200,12 +247,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
@@ -267,7 +312,7 @@ func TestConfigPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
@@ -322,6 +367,183 @@ func TestLoad(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
@@ -37,7 +41,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -51,7 +55,7 @@ func (d *Droid) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -110,17 +114,25 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
|
||||
@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
|
||||
{
|
||||
"model": "existing-ollama-model",
|
||||
"displayName": "existing-ollama-model",
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"apiKey": "ollama",
|
||||
"provider": "generic-chat-completion-api",
|
||||
"maxOutputTokens": 64000,
|
||||
@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
if err := d.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
|
||||
models := settings["customModels"].([]any)
|
||||
entry := models[0].(map[string]any)
|
||||
if entry["maxOutputTokens"] != float64(64000) {
|
||||
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -9,28 +9,48 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Clawdbot struct{}
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Clawdbot) String() string { return "Clawdbot" }
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Clawdbot) Run(model string) error {
|
||||
if _, err := exec.LookPath("clawdbot"); err != nil {
|
||||
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
||||
}
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("clawdbot", "gateway")
|
||||
if !c.onboarded() {
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
@@ -40,22 +60,55 @@ func (c *Clawdbot) Run(model string) error {
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Paths() []string {
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(legacy); err == nil {
|
||||
return []string{legacy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Edit(models []string) error {
|
||||
func (c *Openclaw) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -65,7 +118,8 @@ func (c *Clawdbot) Edit(models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -74,6 +128,8 @@ func (c *Clawdbot) Edit(models []string) error {
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
@@ -90,7 +146,7 @@ func (c *Clawdbot) Edit(models []string) error {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = "http://127.0.0.1:11434/v1"
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
@@ -165,15 +221,18 @@ func (c *Clawdbot) Edit(models []string) error {
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Clawdbot) Models() []string {
|
||||
func (c *Openclaw) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClawdbotIntegration(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawIntegration(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Clawdbot" {
|
||||
t.Errorf("String() = %q, want %q", got, "Clawdbot")
|
||||
if got := c.String(); got != "OpenClaw" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenClaw")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -26,13 +26,13 @@ func TestClawdbotIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEdit(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
@@ -41,8 +41,8 @@ func TestClawdbotEdit(t *testing.T) {
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
@@ -50,9 +50,9 @@ func TestClawdbotEdit(t *testing.T) {
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelExists(t, configPath, "mistral")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "mistral")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
@@ -127,8 +127,8 @@ func TestClawdbotEdit(t *testing.T) {
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertClawdbotModelNotExists(t, configPath, "mistral")
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
@@ -169,12 +169,12 @@ func TestClawdbotEdit(t *testing.T) {
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModels(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawModels(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -185,9 +185,9 @@ func TestClawdbotModels(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
@@ -202,7 +202,7 @@ func TestClawdbotModels(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertClawdbotModelExists(t *testing.T, path, model string) {
|
||||
func assertOpenclawModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
@@ -221,7 +221,7 @@ func assertClawdbotModelExists(t *testing.T, path, model string) {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
|
||||
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
@@ -239,7 +239,7 @@ func assertClawdbotModelNotExists(t *testing.T, path, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
|
||||
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
@@ -252,15 +252,15 @@ func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotPaths(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
@@ -277,12 +277,12 @@ func TestClawdbotPaths(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotModelsEdgeCases(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawModelsEdgeCases(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
@@ -340,11 +340,11 @@ func TestClawdbotModelsEdgeCases(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditSchemaFields(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEditSchemaFields(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -381,20 +381,20 @@ func TestClawdbotEditSchemaFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEditModelNames(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEditModelNames(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "llama3.2:70b")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
@@ -402,8 +402,8 @@ func TestClawdbotEditModelNames(t *testing.T) {
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "library/model:tag")
|
||||
assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
assertOpenclawModelExists(t, configPath, "library/model:tag")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
@@ -411,16 +411,16 @@ func TestClawdbotEditModelNames(t *testing.T) {
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertClawdbotModelExists(t, configPath, "test-model")
|
||||
assertOpenclawModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClawdbotEditAgentsPreservation(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEditAgentsPreservation(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
@@ -457,7 +457,7 @@ func TestClawdbotEditAgentsPreservation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
const testClawdbotFixture = `{
|
||||
const testOpenclawFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
@@ -475,15 +475,15 @@ const testClawdbotFixture = `{
|
||||
}
|
||||
}`
|
||||
|
||||
func TestClawdbotEdit_RoundTrip(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEdit_RoundTrip(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -521,15 +521,15 @@ func TestClawdbotEdit_RoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_Idempotent(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEdit_Idempotent(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
@@ -542,15 +542,15 @@ func TestClawdbotEdit_Idempotent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
@@ -573,12 +573,12 @@ func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawEdit_BackupCreated(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -590,7 +590,7 @@ func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
@@ -605,11 +605,151 @@ func TestClawdbotEdit_BackupCreated(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Clawdbot{}
|
||||
func TestOpenclawClawdbotAlias(t *testing.T) {
|
||||
for _, alias := range []string{"clawdbot", "moltbot"} {
|
||||
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
|
||||
r, ok := integrations[alias]
|
||||
if !ok {
|
||||
t.Fatalf("%s not found in integrations", alias)
|
||||
}
|
||||
if _, ok := r.(*Openclaw); !ok {
|
||||
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(alias+" is hidden from selector", func(t *testing.T) {
|
||||
if !integrationAliases[alias] {
|
||||
t.Errorf("%s should be in integrationAliases", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawLegacyPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
|
||||
t.Errorf("expected legacy path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(newDir, "openclaw.json") {
|
||||
t.Errorf("expected new path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models reads from legacy path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "llama3.2" {
|
||||
t.Errorf("expected [llama3.2], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
|
||||
}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "new-model" {
|
||||
t.Errorf("expected [new-model], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "new" {
|
||||
t.Errorf("expected theme from new config, got %v", cfg["theme"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit migrates from legacy config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should write to new path
|
||||
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatal("expected new config file to be created")
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("legacy theme setting was not migrated")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
@@ -623,3 +763,116 @@ func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawOnboarded(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns false when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no config exists")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no wizard section")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when wizard.lastRunAt is set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checks legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when legacy config has wizard.lastRunAt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
// New path has no wizard marker
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
// Legacy has wizard marker
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false - should prefer new path which has no wizard marker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false for corrupted JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles wrong type for wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard is wrong type")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -9,14 +10,57 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string) error {
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
@@ -30,7 +74,7 @@ func (o *OpenCode) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode")
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -88,7 +132,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -111,6 +155,8 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -120,12 +166,29 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -495,6 +496,166 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
237
cmd/config/pi.go
Normal file
237
cmd/config/pi.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||
type Pi struct{}
|
||||
|
||||
func (p *Pi) String() string { return "Pi" }
|
||||
|
||||
func (p *Pi) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (p *Pi) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if _, err := os.Stat(modelsPath); err == nil {
|
||||
paths = append(paths, modelsPath)
|
||||
}
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
if _, err := os.Stat(settingsPath); err == nil {
|
||||
paths = append(paths, settingsPath)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (p *Pi) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
providers, ok := config["providers"].(map[string]any)
|
||||
if !ok {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"baseUrl": envconfig.Host().String() + "/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
}
|
||||
}
|
||||
|
||||
existingModels, ok := ollama["models"].([]any)
|
||||
if !ok {
|
||||
existingModels = make([]any, 0)
|
||||
}
|
||||
|
||||
// Build set of selected models to track which need to be added
|
||||
selectedSet := make(map[string]bool, len(models))
|
||||
for _, m := range models {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
// User-managed model (no _launch marker) - always preserve
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Ollama-managed and still selected - keep it
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add newly selected models that weren't already in the list
|
||||
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||
ctx := context.Background()
|
||||
for _, model := range models {
|
||||
if selectedSet[model] {
|
||||
newModels = append(newModels, createConfig(ctx, client, model))
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = newModels
|
||||
providers["ollama"] = ollama
|
||||
config["providers"] = providers
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update settings.json with default provider and model
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
settings := make(map[string]any)
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
_ = json.Unmarshal(data, &settings)
|
||||
}
|
||||
|
||||
settings["defaultProvider"] = "ollama"
|
||||
settings["defaultModel"] = models[0]
|
||||
|
||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, settingsData)
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := readJSONFile(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers, _ := config["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range models {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
slices.Sort(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
||||
func isPiOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection
|
||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": modelID,
|
||||
"_launch": true,
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
cfg["input"] = []string{"text", "image"}
|
||||
} else {
|
||||
cfg["input"] = []string{"text"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
830
cmd/config/pi_test.go
Normal file
830
cmd/config/pi_test.go
Normal file
@@ -0,0 +1,830 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPiIntegration(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := pi.String(); got != "Pi" {
|
||||
t.Errorf("String() = %q, want %q", got, "Pi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = pi
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = pi
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiPaths(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns empty when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("Paths() = %v, want empty", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiEdit(t *testing.T) {
|
||||
// Mock Ollama server for createConfig calls during Edit
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pi := &Pi{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
}
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
t.Run("returns nil for empty models", func(t *testing.T) {
|
||||
if err := pi.Edit([]string{}); err != nil {
|
||||
t.Errorf("Edit([]) error = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates config with models", func(t *testing.T) {
|
||||
cleanup()
|
||||
|
||||
models := []string{"llama3.2", "qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
|
||||
providers, ok := cfg["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Config missing providers")
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Providers missing ollama")
|
||||
}
|
||||
|
||||
modelsArray, ok := ollama["models"].([]any)
|
||||
if !ok || len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %v", modelsArray)
|
||||
}
|
||||
|
||||
if ollama["baseUrl"] == nil {
|
||||
t.Error("Missing baseUrl")
|
||||
}
|
||||
if ollama["api"] != "openai-completions" {
|
||||
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "ollama" {
|
||||
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://custom:8080/v1",
|
||||
"api": "custom-api",
|
||||
"apiKey": "custom-key",
|
||||
"models": [
|
||||
{"id": "old-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"new-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
|
||||
if ollama["baseUrl"] != "http://custom:8080/v1" {
|
||||
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
|
||||
}
|
||||
if ollama["api"] != "custom-api" {
|
||||
t.Errorf("Custom api not preserved, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "custom-key" {
|
||||
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
|
||||
}
|
||||
|
||||
modelsArray := ollama["models"].([]any)
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
|
||||
} else {
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
if modelEntry["id"] != "new-model" {
|
||||
t.Errorf("Expected new-model, got %v", modelEntry["id"])
|
||||
}
|
||||
// Verify _launch marker is present
|
||||
if modelEntry["_launch"] != true {
|
||||
t.Errorf("Expected _launch marker to be true")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Old models must have _launch marker to be managed by us
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "old-model-1", "_launch": true},
|
||||
{"id": "old-model-2", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"new-model-1", "new-model-2"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
|
||||
t.Errorf("Expected new models, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
|
||||
t.Errorf("Old models should have been removed, got %v", modelIDs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles partial overlap in model list", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Models must have _launch marker to be managed
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "keep-model", "_launch": true},
|
||||
{"id": "remove-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"keep-model", "add-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
|
||||
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["remove-model"] {
|
||||
t.Errorf("remove-model should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model, got %d", len(modelsArray))
|
||||
}
|
||||
})
|
||||
|
||||
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
|
||||
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// User has manually configured models in ollama provider (no _launch marker)
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "user-model-1"},
|
||||
{"id": "user-model-2", "customField": "preserved"},
|
||||
{"id": "ollama-managed", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a new ollama-managed model
|
||||
newModels := []string{"new-ollama-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
// Should have: new-ollama-model (managed) + 2 user models (preserved)
|
||||
if len(modelsArray) != 3 {
|
||||
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]map[string]any)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = modelObj
|
||||
}
|
||||
|
||||
// Verify new model has _launch marker
|
||||
if m, ok := modelIDs["new-ollama-model"]; !ok {
|
||||
t.Errorf("new-ollama-model should be present")
|
||||
} else if m["_launch"] != true {
|
||||
t.Errorf("new-ollama-model should have _launch marker")
|
||||
}
|
||||
|
||||
// Verify user models are preserved
|
||||
if _, ok := modelIDs["user-model-1"]; !ok {
|
||||
t.Errorf("user-model-1 should be preserved")
|
||||
}
|
||||
if _, ok := modelIDs["user-model-2"]; !ok {
|
||||
t.Errorf("user-model-2 should be preserved")
|
||||
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
|
||||
t.Errorf("user-model-2 customField should be preserved")
|
||||
}
|
||||
|
||||
// Verify old ollama-managed model is removed (not in new list)
|
||||
if _, ok := modelIDs["ollama-managed"]; ok {
|
||||
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create existing settings with other fields
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
existingSettings := `{
|
||||
"theme": "dark",
|
||||
"customSetting": "value",
|
||||
"defaultProvider": "anthropic",
|
||||
"defaultModel": "claude-3"
|
||||
}`
|
||||
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"llama3.2"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
// Verify defaultProvider is set to ollama
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
|
||||
// Verify defaultModel is set to first model
|
||||
if settings["defaultModel"] != "llama3.2" {
|
||||
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
|
||||
}
|
||||
|
||||
// Verify other fields are preserved
|
||||
if settings["theme"] != "dark" {
|
||||
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
|
||||
}
|
||||
if settings["customSetting"] != "value" {
|
||||
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
models := []string{"qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("settings.json should be created: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "qwen3:8b" {
|
||||
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create corrupt settings
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "test-model" {
|
||||
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiModels(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns models from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "llama3.2"},
|
||||
{"id": "qwen3:8b"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("Models() returned %d models, want 2", len(models))
|
||||
}
|
||||
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
|
||||
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns sorted models", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "z-model"},
|
||||
{"id": "a-model"},
|
||||
{"id": "m-model"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
|
||||
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when models array is missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil when models array is missing", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil for corrupt config", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsPiOllamaModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg map[string]any
|
||||
want bool
|
||||
}{
|
||||
{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
|
||||
{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
|
||||
{"without _launch", map[string]any{"id": "m"}, false},
|
||||
{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
|
||||
{"empty map", map[string]any{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isPiOllamaModel(tt.cfg); got != tt.want {
|
||||
t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llava:7b")
|
||||
|
||||
if cfg["id"] != "llava:7b" {
|
||||
t.Errorf("id = %v, want llava:7b", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set for non-thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwq")
|
||||
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
if cfg["contextWindow"] != 131072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwen3-vision")
|
||||
|
||||
input := cfg["input"].([]string)
|
||||
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", input)
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true")
|
||||
}
|
||||
if cfg["contextWindow"] != 32768 {
|
||||
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns minimal config when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "missing-model")
|
||||
|
||||
if cfg["id"] != "missing-model" {
|
||||
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
// Should not have capability fields
|
||||
if _, ok := cfg["input"]; ok {
|
||||
t.Error("input should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set when show fails")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "test-model")
|
||||
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set for zero value")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure Capability constants used in createConfig match expected values
|
||||
func TestPiCapabilityConstants(t *testing.T) {
|
||||
if model.CapabilityVision != "vision" {
|
||||
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
|
||||
}
|
||||
if model.CapabilityThinking != "thinking" {
|
||||
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
|
||||
}
|
||||
}
|
||||
@@ -3,461 +3,34 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
@@ -483,17 +56,3 @@ func confirmPrompt(prompt string) (bool, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,651 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
@@ -659,255 +17,3 @@ func TestErrCancelled(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,19 +10,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
return errNotRunning
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
|
||||
109
cmd/tui/confirm.go
Normal file
109
cmd/tui/confirm.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
confirmActiveStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
confirmInactiveStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
)
|
||||
|
||||
type confirmModel struct {
|
||||
prompt string
|
||||
yes bool
|
||||
confirmed bool
|
||||
cancelled bool
|
||||
width int
|
||||
}
|
||||
|
||||
func (m confirmModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc", "n":
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
case "y":
|
||||
m.yes = true
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
case "enter":
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
case "left", "h":
|
||||
m.yes = true
|
||||
case "right", "l":
|
||||
m.yes = false
|
||||
case "tab":
|
||||
m.yes = !m.yes
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m confirmModel) View() string {
|
||||
if m.confirmed || m.cancelled {
|
||||
return ""
|
||||
}
|
||||
|
||||
var yesBtn, noBtn string
|
||||
if m.yes {
|
||||
yesBtn = confirmActiveStyle.Render(" Yes ")
|
||||
noBtn = confirmInactiveStyle.Render(" No ")
|
||||
} else {
|
||||
yesBtn = confirmInactiveStyle.Render(" Yes ")
|
||||
noBtn = confirmActiveStyle.Render(" No ")
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
|
||||
s += " " + yesBtn + " " + noBtn + "\n\n"
|
||||
s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RunConfirm shows a bubbletea yes/no confirmation prompt.
|
||||
// Returns true if the user confirmed, false if cancelled.
|
||||
func RunConfirm(prompt string) (bool, error) {
|
||||
m := confirmModel{
|
||||
prompt: prompt,
|
||||
yes: true, // default to yes
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error running confirm: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(confirmModel)
|
||||
if fm.cancelled {
|
||||
return false, ErrCancelled
|
||||
}
|
||||
|
||||
return fm.yes, nil
|
||||
}
|
||||
208
cmd/tui/confirm_test.go
Normal file
208
cmd/tui/confirm_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func TestConfirmModel_DefaultsToYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download test?", yes: true}
|
||||
if !m.yes {
|
||||
t.Error("should default to yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Download qwen3:8b?") {
|
||||
t.Error("should contain the prompt text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsButtons(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Yes") {
|
||||
t.Error("should contain Yes button")
|
||||
}
|
||||
if !strings.Contains(got, "No") {
|
||||
t.Error("should contain No button")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "enter confirm") {
|
||||
t.Error("should contain help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", confirmed: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after confirmation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", cancelled: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if !fm.yes {
|
||||
t.Error("enter with yes selected should keep yes=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if fm.yes {
|
||||
t.Error("enter with no selected should keep yes=false")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EscCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_CtrlCCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_NCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("'n' should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("'n' should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_YConfirmsYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("'y' should set confirmed=true")
|
||||
}
|
||||
if !fm.yes {
|
||||
t.Error("'y' should set yes=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("'y' should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
// Right moves to No
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.yes {
|
||||
t.Error("right/l should move to No")
|
||||
}
|
||||
if fm.confirmed || fm.cancelled {
|
||||
t.Error("navigation should not confirm or cancel")
|
||||
}
|
||||
|
||||
// Left moves back to Yes
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}})
|
||||
fm = updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("left/h should move to Yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_TabToggles(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.yes {
|
||||
t.Error("tab should toggle from Yes to No")
|
||||
}
|
||||
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
fm = updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("tab should toggle from No to Yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.width != 100 {
|
||||
t.Errorf("expected width 100, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", width: 80}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
if cmd == nil {
|
||||
t.Error("resize (width already set) should return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
|
||||
if cmd != nil {
|
||||
t.Error("initial WindowSizeMsg should not return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ViewMaxWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true, width: 40}
|
||||
got := m.View()
|
||||
// Just ensure it doesn't panic and returns content
|
||||
if got == "" {
|
||||
t.Error("View with width set should still return content")
|
||||
}
|
||||
}
|
||||
654
cmd/tui/selector.go
Normal file
654
cmd/tui/selector.go
Normal file
@@ -0,0 +1,654 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
selectorTitleStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
selectorItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4)
|
||||
|
||||
selectorSelectedItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
selectorDescStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorDescLineStyle = selectorDescStyle.
|
||||
PaddingLeft(6)
|
||||
|
||||
selectorFilterStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
selectorInputStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
|
||||
|
||||
selectorCheckboxStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
|
||||
|
||||
selectorMoreStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(6).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
sectionHeaderStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
|
||||
)
|
||||
|
||||
const maxSelectorItems = 10
|
||||
|
||||
// ErrCancelled is returned when the user cancels the selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// selectorModel is the bubbletea model for single selection.
|
||||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
selected string
|
||||
cancelled bool
|
||||
helpText string
|
||||
width int
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// otherStart returns the index of the first non-recommended item in the filtered list.
|
||||
// When filtering, all items scroll together so this returns 0.
|
||||
func (m selectorModel) otherStart() int {
|
||||
if m.filter != "" {
|
||||
return 0
|
||||
}
|
||||
filtered := m.filteredItems()
|
||||
for i, item := range filtered {
|
||||
if !item.Recommended {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(filtered)
|
||||
}
|
||||
|
||||
// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
|
||||
// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
|
||||
// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
|
||||
func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
|
||||
filtered := m.filteredItems()
|
||||
otherStart := m.otherStart()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// updateScroll adjusts scrollOffset based on cursor position.
|
||||
// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
|
||||
// When filtering, it's relative to the full filtered list.
|
||||
func (m *selectorModel) updateScroll(otherStart int) {
|
||||
if m.filter != "" {
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in recommended section — reset "More" scroll to top
|
||||
if m.cursor < otherStart {
|
||||
m.scrollOffset = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in "More" section — scroll relative to others
|
||||
posInOthers := m.cursor - otherStart
|
||||
maxOthers := maxSelectorItems - otherStart
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
if posInOthers < m.scrollOffset {
|
||||
m.scrollOffset = posInOthers
|
||||
}
|
||||
if posInOthers >= m.scrollOffset+maxOthers {
|
||||
m.scrollOffset = posInOthers - maxOthers + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.selected = filtered[m.cursor].Name
|
||||
}
|
||||
return m, tea.Quit
|
||||
|
||||
default:
|
||||
m.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// renderContent renders the selector content (title, items, help text) without
|
||||
// checking the cancelled/selected state. This is used by both View() (standalone mode)
|
||||
// and by the TUI modal which embeds a selectorModel.
|
||||
func (m selectorModel) renderContent() string {
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else if m.filter != "" {
|
||||
s.WriteString(sectionHeaderStyle.Render("Top Results"))
|
||||
s.WriteString("\n")
|
||||
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
// Split into pinned recommended and scrollable others
|
||||
var recItems, otherItems []int
|
||||
for i, item := range filtered {
|
||||
if item.Recommended {
|
||||
recItems = append(recItems, i)
|
||||
} else {
|
||||
otherItems = append(otherItems, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Always render all recommended items (pinned)
|
||||
if len(recItems) > 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(otherItems) > 0 {
|
||||
s.WriteString("\n")
|
||||
s.WriteString(sectionHeaderStyle.Render("More"))
|
||||
s.WriteString("\n")
|
||||
|
||||
maxOthers := maxSelectorItems - len(recItems)
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
displayCount := min(len(otherItems), maxOthers)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
help := "↑/↓ navigate • enter select • esc cancel"
|
||||
if m.helpText != "" {
|
||||
help = m.helpText
|
||||
}
|
||||
s.WriteString(selectorHelpStyle.Render(help))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func (m selectorModel) View() string {
|
||||
if m.cancelled || m.selected != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := m.renderContent()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(selectorModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.selected, nil
|
||||
}
|
||||
|
||||
// multiSelectorModel is the bubbletea model for multi selection.
|
||||
type multiSelectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
m := multiSelectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) toggleItem() {
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) == 0 || m.cursor >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[m.cursor]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.checked[origIdx] {
|
||||
delete(m.checked, origIdx)
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == origIdx {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.checked[origIdx] = true
|
||||
m.checkOrder = append(m.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) selectedCount() int {
|
||||
return len(m.checkOrder)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
m.toggleItem()
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.scrollOffset -= maxSelectorItems
|
||||
if m.scrollOffset < 0 {
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) View() string {
|
||||
if m.cancelled || m.confirmed {
|
||||
return ""
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
shownRecHeader := false
|
||||
prevWasRec := false
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.filter == "" {
|
||||
if item.Recommended && !shownRecHeader {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
shownRecHeader = true
|
||||
} else if !item.Recommended && prevWasRec {
|
||||
s.WriteString("\n")
|
||||
}
|
||||
prevWasRec = item.Recommended
|
||||
}
|
||||
|
||||
var checkbox string
|
||||
if m.checked[origIdx] {
|
||||
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
|
||||
} else {
|
||||
checkbox = selectorCheckboxStyle.Render("[ ]")
|
||||
}
|
||||
|
||||
var line string
|
||||
if idx == m.cursor {
|
||||
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
|
||||
} else {
|
||||
line = " " + checkbox + " " + item.Name
|
||||
}
|
||||
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
line += " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
s.WriteString(line)
|
||||
s.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := newMultiSelectorModel(title, items, preChecked)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
410
cmd/tui/selector_test.go
Normal file
410
cmd/tui/selector_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func items(names ...string) []SelectItem {
|
||||
var out []SelectItem
|
||||
for _, n := range names {
|
||||
out = append(out, SelectItem{Name: n})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func recItems(names ...string) []SelectItem {
|
||||
var out []SelectItem
|
||||
for _, n := range names {
|
||||
out = append(out, SelectItem{Name: n, Recommended: true})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mixedItems() []SelectItem {
|
||||
return []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "rec-b", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
{Name: "other-2"},
|
||||
{Name: "other-3"},
|
||||
{Name: "other-4"},
|
||||
{Name: "other-5"},
|
||||
{Name: "other-6"},
|
||||
{Name: "other-7"},
|
||||
{Name: "other-8"},
|
||||
{Name: "other-9"},
|
||||
{Name: "other-10"},
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredItems(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
items []SelectItem
|
||||
filter string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "no filter returns all",
|
||||
items: items("alpha", "beta", "gamma"),
|
||||
filter: "",
|
||||
want: []string{"alpha", "beta", "gamma"},
|
||||
},
|
||||
{
|
||||
name: "filter matches substring",
|
||||
items: items("llama3.2", "qwen3:8b", "llama2"),
|
||||
filter: "llama",
|
||||
want: []string{"llama3.2", "llama2"},
|
||||
},
|
||||
{
|
||||
name: "filter is case insensitive",
|
||||
items: items("Qwen3:8b", "llama3.2"),
|
||||
filter: "QWEN",
|
||||
want: []string{"Qwen3:8b"},
|
||||
},
|
||||
{
|
||||
name: "no matches",
|
||||
items: items("alpha", "beta"),
|
||||
filter: "zzz",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{items: tt.items, filter: tt.filter}
|
||||
got := m.filteredItems()
|
||||
var gotNames []string
|
||||
for _, item := range got {
|
||||
gotNames = append(gotNames, item.Name)
|
||||
}
|
||||
if len(gotNames) != len(tt.want) {
|
||||
t.Fatalf("got %v, want %v", gotNames, tt.want)
|
||||
}
|
||||
for i := range tt.want {
|
||||
if gotNames[i] != tt.want[i] {
|
||||
t.Errorf("index %d: got %q, want %q", i, gotNames[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOtherStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
items []SelectItem
|
||||
filter string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "all recommended",
|
||||
items: recItems("a", "b", "c"),
|
||||
want: 3,
|
||||
},
|
||||
{
|
||||
name: "none recommended",
|
||||
items: items("a", "b"),
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "rec-b", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
{Name: "other-2"},
|
||||
},
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
items: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "filtering returns 0",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
},
|
||||
filter: "rec",
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{items: tt.items, filter: tt.filter}
|
||||
if got := m.otherStart(); got != tt.want {
|
||||
t.Errorf("otherStart() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateScroll(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cursor int
|
||||
offset int
|
||||
otherStart int
|
||||
filter string
|
||||
wantOffset int
|
||||
}{
|
||||
{
|
||||
name: "cursor in recommended resets scroll",
|
||||
cursor: 1,
|
||||
offset: 5,
|
||||
otherStart: 3,
|
||||
wantOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "cursor at start of others",
|
||||
cursor: 2,
|
||||
offset: 0,
|
||||
otherStart: 2,
|
||||
wantOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "cursor scrolls down in others",
|
||||
cursor: 12,
|
||||
offset: 0,
|
||||
otherStart: 2,
|
||||
wantOffset: 3, // posInOthers=10, maxOthers=8, 10-8+1=3
|
||||
},
|
||||
{
|
||||
name: "cursor scrolls up in others",
|
||||
cursor: 4,
|
||||
offset: 5,
|
||||
otherStart: 2,
|
||||
wantOffset: 2, // posInOthers=2 < offset=5
|
||||
},
|
||||
{
|
||||
name: "filter mode standard scroll down",
|
||||
cursor: 12,
|
||||
offset: 0,
|
||||
filter: "x",
|
||||
otherStart: 0,
|
||||
wantOffset: 3, // 12 - 10 + 1 = 3
|
||||
},
|
||||
{
|
||||
name: "filter mode standard scroll up",
|
||||
cursor: 2,
|
||||
offset: 5,
|
||||
filter: "x",
|
||||
otherStart: 0,
|
||||
wantOffset: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{
|
||||
cursor: tt.cursor,
|
||||
scrollOffset: tt.offset,
|
||||
filter: tt.filter,
|
||||
}
|
||||
m.updateScroll(tt.otherStart)
|
||||
if m.scrollOffset != tt.wantOffset {
|
||||
t.Errorf("scrollOffset = %d, want %d", m.scrollOffset, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
},
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "Recommended") {
|
||||
t.Error("should contain 'Recommended' header")
|
||||
}
|
||||
if !strings.Contains(content, "More") {
|
||||
t.Error("should contain 'More' header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_FilteredHeader(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha", "beta", "alphabet"),
|
||||
filter: "alpha",
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "Top Results") {
|
||||
t.Error("filtered view should contain 'Top Results' header")
|
||||
}
|
||||
if strings.Contains(content, "Recommended") {
|
||||
t.Error("filtered view should not contain 'Recommended' header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_NoMatches(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha"),
|
||||
filter: "zzz",
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "(no matches)") {
|
||||
t.Error("should show '(no matches)' when filter has no results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SelectedItemIndicator(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha", "beta"),
|
||||
cursor: 0,
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("selected item should have ▸ indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_Description(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: []SelectItem{
|
||||
{Name: "alpha", Description: "the first letter"},
|
||||
},
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "the first letter") {
|
||||
t.Error("should render item description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_PinnedRecommended(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: mixedItems(),
|
||||
// cursor deep in "More" section
|
||||
cursor: 8,
|
||||
scrollOffset: 3,
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
// Recommended items should always be visible (pinned)
|
||||
if !strings.Contains(content, "rec-a") {
|
||||
t.Error("recommended items should always be rendered (pinned)")
|
||||
}
|
||||
if !strings.Contains(content, "rec-b") {
|
||||
t.Error("recommended items should always be rendered (pinned)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_MoreOverflowIndicator(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: mixedItems(), // 2 rec + 10 other = 12 total, maxSelectorItems=10
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "... and") {
|
||||
t.Error("should show overflow indicator when more items than visible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_CursorBounds(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("a", "b", "c"),
|
||||
cursor: 0,
|
||||
}
|
||||
|
||||
// Up at top stays at 0
|
||||
m.updateNavigation(keyMsg(KeyUp))
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should stay at 0 when pressing up at top, got %d", m.cursor)
|
||||
}
|
||||
|
||||
// Down moves to 1
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be 1 after down, got %d", m.cursor)
|
||||
}
|
||||
|
||||
// Down to end
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
if m.cursor != 2 {
|
||||
t.Errorf("cursor should be 2 at bottom, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_FilterResetsState(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("alpha", "beta"),
|
||||
cursor: 1,
|
||||
scrollOffset: 5,
|
||||
}
|
||||
|
||||
m.updateNavigation(runeMsg('x'))
|
||||
if m.filter != "x" {
|
||||
t.Errorf("filter should be 'x', got %q", m.filter)
|
||||
}
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should reset to 0 on filter, got %d", m.cursor)
|
||||
}
|
||||
if m.scrollOffset != 0 {
|
||||
t.Errorf("scrollOffset should reset to 0 on filter, got %d", m.scrollOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_Backspace(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("alpha"),
|
||||
filter: "abc",
|
||||
cursor: 1,
|
||||
}
|
||||
|
||||
m.updateNavigation(keyMsg(KeyBackspace))
|
||||
if m.filter != "ab" {
|
||||
t.Errorf("filter should be 'ab' after backspace, got %q", m.filter)
|
||||
}
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should reset to 0 on backspace, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
const (
|
||||
KeyUp keyType = iota
|
||||
KeyDown keyType = iota
|
||||
KeyBackspace keyType = iota
|
||||
)
|
||||
|
||||
func keyMsg(k keyType) tea.KeyMsg {
|
||||
switch k {
|
||||
case KeyUp:
|
||||
return tea.KeyMsg{Type: tea.KeyUp}
|
||||
case KeyDown:
|
||||
return tea.KeyMsg{Type: tea.KeyDown}
|
||||
case KeyBackspace:
|
||||
return tea.KeyMsg{Type: tea.KeyBackspace}
|
||||
default:
|
||||
return tea.KeyMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
func runeMsg(r rune) tea.KeyMsg {
|
||||
return tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}}
|
||||
}
|
||||
128
cmd/tui/signin.go
Normal file
128
cmd/tui/signin.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
spinner int
|
||||
width int
|
||||
userName string
|
||||
cancelled bool
|
||||
}
|
||||
|
||||
func (m signInModel) Init() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.spinner++
|
||||
if m.spinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
m.userName = msg.userName
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m signInModel) View() string {
|
||||
if m.userName != "" {
|
||||
return ""
|
||||
}
|
||||
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
|
||||
}
|
||||
|
||||
func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
||||
urlColor := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("117"))
|
||||
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
|
||||
if width > 4 {
|
||||
urlWrap = urlWrap.Width(width - 4)
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
|
||||
// Padding is outside the hyperlink so spaces don't get underlined.
|
||||
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(link))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
frame + " Waiting for sign in to complete..."))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
signInURL: signInURL,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running sign-in: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(signInModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.userName, nil
|
||||
}
|
||||
175
cmd/tui/signin_test.go
Normal file
175
cmd/tui/signin_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func TestRenderSignIn_ContainsModelName(t *testing.T) {
|
||||
got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
|
||||
if !strings.Contains(got, "glm-4.7:cloud") {
|
||||
t.Error("should contain model name")
|
||||
}
|
||||
if !strings.Contains(got, "please sign in") {
|
||||
t.Error("should contain sign-in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsURL(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
if !strings.Contains(got, url) {
|
||||
t.Errorf("should contain URL %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
|
||||
// Should contain OSC 8 open sequence with the URL
|
||||
osc8Open := "\033]8;;" + url + "\033\\"
|
||||
if !strings.Contains(got, osc8Open) {
|
||||
t.Error("should contain OSC 8 open sequence with URL")
|
||||
}
|
||||
|
||||
// Should contain OSC 8 close sequence
|
||||
osc8Close := "\033]8;;\033\\"
|
||||
if !strings.Contains(got, osc8Close) {
|
||||
t.Error("should contain OSC 8 close sequence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "Waiting for sign in to complete") {
|
||||
t.Error("should contain waiting message")
|
||||
}
|
||||
if !strings.Contains(got, "⠋") {
|
||||
t.Error("should contain first spinner frame at spinner=0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
|
||||
got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
|
||||
if got0 == got1 {
|
||||
t.Error("different spinner values should produce different output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "esc cancel") {
|
||||
t.Error("should contain esc cancel help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_EscCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_CtrlCCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInQuitsClean(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "alice" {
|
||||
t.Errorf("expected userName 'alice', got %q", fm.userName)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("successful sign-in should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInViewClears(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
userName: "alice",
|
||||
}
|
||||
|
||||
got := m.View()
|
||||
if got != "" {
|
||||
t.Errorf("View should return empty string after sign-in, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_NotSignedInContinues(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(signInCheckMsg{signedIn: false})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "" {
|
||||
t.Error("should not set userName when not signed in")
|
||||
}
|
||||
if fm.cancelled {
|
||||
t.Error("should not cancel when check returns not signed in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
fm := updated.(signInModel)
|
||||
if fm.width != 120 {
|
||||
t.Errorf("expected width 120, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
spinner: 0,
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInTickMsg{})
|
||||
fm := updated.(signInModel)
|
||||
if fm.spinner != 1 {
|
||||
t.Errorf("expected spinner=1, got %d", fm.spinner)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("tick should return a command")
|
||||
}
|
||||
}
|
||||
603
cmd/tui/tui.go
Normal file
603
cmd/tui/tui.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
var (
|
||||
versionStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
menuItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2)
|
||||
|
||||
menuSelectedItemStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
menuDescStyle = selectorDescStyle.
|
||||
PaddingLeft(4)
|
||||
|
||||
greyedStyle = menuItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
greyedSelectedStyle = menuSelectedItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
modelStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
notInstalledStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
|
||||
var mainMenuItems = []menuItem{
|
||||
{
|
||||
title: "Run a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Agentic coding across large codebases",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Codex",
|
||||
description: "OpenAI's open-source coding agent",
|
||||
integration: "codex",
|
||||
},
|
||||
{
|
||||
title: "Launch OpenClaw",
|
||||
description: "Personal AI with 100+ skills",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
|
||||
var othersMenuItem = menuItem{
|
||||
title: "More...",
|
||||
description: "Show additional integrations",
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
||||
func getOtherIntegrations() []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"run": true, // not an integration but in the pinned list
|
||||
}
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration != "" {
|
||||
pinned[item.integration] = true
|
||||
}
|
||||
}
|
||||
|
||||
var others []menuItem
|
||||
for _, info := range config.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
desc := info.Description
|
||||
if desc == "" {
|
||||
desc = "Open " + info.DisplayName + " integration"
|
||||
}
|
||||
others = append(others, menuItem{
|
||||
title: "Launch " + info.DisplayName,
|
||||
description: desc,
|
||||
integration: info.Name,
|
||||
})
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool
|
||||
changeModel bool
|
||||
showOthers bool
|
||||
availableModels map[string]bool
|
||||
err error
|
||||
|
||||
showingModal bool
|
||||
modalSelector selectorModel
|
||||
modalItems []SelectItem
|
||||
|
||||
showingSignIn bool
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
}
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
var items []SelectItem
|
||||
for _, item := range modelItems {
|
||||
items = append(items, SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (m *model) openModelModal() {
|
||||
m.modalItems = m.buildModalItems()
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
helpText: "↑/↓ navigate • enter select • ← back",
|
||||
}
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, mdl := range models.Models {
|
||||
m.availableModels[mdl.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
}
|
||||
}
|
||||
|
||||
func isOthersIntegration(name string) bool {
|
||||
for _, item := range getOtherIntegrations() {
|
||||
if item.integration == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
||||
wasSet := m.width > 0
|
||||
m.width = wmsg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if _, ok := msg.(clearStatusMsg); ok {
|
||||
m.statusMsg = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
if m.signInFromModal {
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
||||
m.modalSelector.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
m.showOthers = false
|
||||
m.buildItems()
|
||||
}
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
// Auto-expand "Others..." when cursor lands on it
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.buildItems()
|
||||
// cursor now points at the first "other" integration
|
||||
}
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
m.openModelModal()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = menuSelectedItemStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
} else if item.isRunModel && m.cursor == i {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
}
|
||||
}
|
||||
s += menuDescStyle.Render(desc) + "\n\n"
|
||||
}
|
||||
|
||||
if m.statusMsg != "" {
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
PaddingBottom(1).
|
||||
PaddingRight(2)
|
||||
|
||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderSignInDialog() string {
|
||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
||||
}
|
||||
|
||||
type Selection int
|
||||
|
||||
const (
|
||||
SelectionNone Selection = iota
|
||||
SelectionRunModel
|
||||
SelectionChangeRunModel
|
||||
SelectionIntegration // Generic integration selection
|
||||
SelectionChangeIntegration // Generic change model for integration
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Selection Selection
|
||||
Integration string // integration name if applicable
|
||||
Model string // model name if selected from modal
|
||||
}
|
||||
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
}
|
||||
@@ -313,8 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
case "GlmOcrForConditionalGeneration":
|
||||
conv = &glmOcrModel{}
|
||||
case "Lfm2ForCausalLM":
|
||||
conv = &lfm2Model{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
conv = &qwen3NextModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
455
convert/convert_glmocr.go
Normal file
455
convert/convert_glmocr.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
|
||||
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
|
||||
//
|
||||
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
|
||||
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
|
||||
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
rotaryDim := int(float32(headDim) * partialRotaryFactor)
|
||||
if rotaryDim%2 != 0 {
|
||||
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
|
||||
}
|
||||
|
||||
// Handle 1D (bias) or 2D (weight) tensors
|
||||
is1D := len(shape) == 1
|
||||
var inFeatures int
|
||||
if is1D {
|
||||
inFeatures = 1
|
||||
} else {
|
||||
inFeatures = int(shape[1])
|
||||
}
|
||||
outFeatures := int(shape[0])
|
||||
nEffectiveHeads := outFeatures / headDim
|
||||
|
||||
if nEffectiveHeads != nHeads {
|
||||
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
|
||||
}
|
||||
|
||||
// Reshape to [n_heads, head_dim, in_features]
|
||||
reshaped := make([]float32, len(data))
|
||||
copy(reshaped, data)
|
||||
|
||||
// Permute the rotary dimensions: even indices first, then odd
|
||||
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
|
||||
result := make([]float32, len(data))
|
||||
halfRotary := rotaryDim / 2
|
||||
|
||||
for h := range nEffectiveHeads {
|
||||
for f := range inFeatures {
|
||||
for i := range halfRotary {
|
||||
// Even dim (0, 2, 4, ...) -> position i
|
||||
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
|
||||
dstIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
|
||||
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
|
||||
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
|
||||
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
}
|
||||
|
||||
// Non-rotary part: copy as-is
|
||||
for i := rotaryDim; i < headDim; i++ {
|
||||
srcIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[srcIdx] = reshaped[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
type glmOcrModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeParameters struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
MRopeSection []int32 `json:"mrope_section"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
|
||||
VisionConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
Depth uint32 `json:"depth"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
OutHiddenSize uint32 `json:"out_hidden_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
|
||||
ImageStartTokenID uint32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||
VideoStartTokenID uint32 `json:"video_start_token_id"`
|
||||
VideoEndTokenID uint32 `json:"video_end_token_id"`
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
VideoTokenID uint32 `json:"video_token_id"`
|
||||
|
||||
// Preprocessor config (preprocessor_config.json)
|
||||
Preprocessor struct {
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
MergeSize uint32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
} `json:"-"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*glmOcrModel)(nil)
|
||||
|
||||
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(bts, &m.Preprocessor)
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "glmocr"
|
||||
|
||||
// Text model parameters
|
||||
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
|
||||
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
|
||||
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
|
||||
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
|
||||
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
|
||||
kv["glmocr.attention.key_length"] = headDim
|
||||
kv["glmocr.attention.value_length"] = headDim
|
||||
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
|
||||
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
|
||||
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
|
||||
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
|
||||
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
|
||||
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
|
||||
}
|
||||
|
||||
// Vision model parameters
|
||||
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
|
||||
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
|
||||
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
|
||||
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
|
||||
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
|
||||
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
|
||||
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
|
||||
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
|
||||
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
|
||||
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
|
||||
|
||||
// Preprocessor-derived image settings (min/max pixels and normalization)
|
||||
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
|
||||
if m.Preprocessor.Size.ShortestEdge > 0 {
|
||||
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
|
||||
}
|
||||
if m.Preprocessor.Size.LongestEdge > 0 {
|
||||
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
|
||||
}
|
||||
if len(m.Preprocessor.ImageMean) == 3 {
|
||||
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
|
||||
}
|
||||
if len(m.Preprocessor.ImageStd) == 3 {
|
||||
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
kv["glmocr.image_token_id"] = m.ImageTokenID
|
||||
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
|
||||
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
|
||||
kv["glmocr.video_token_id"] = m.VideoTokenID
|
||||
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
|
||||
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
|
||||
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
|
||||
skipLayer := func(name string) bool {
|
||||
// Tensor names are already replaced to "blk.N.xxx" format
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(name)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return blkNum >= numLayers
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip next-n prediction layers (layers >= num_hidden_layers)
|
||||
if skipLayer(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split ffn_gate_up into separate gate and up projections
|
||||
if strings.Contains(name, "ffn_gate_up") {
|
||||
for t := range splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
|
||||
) {
|
||||
out = append(out, t)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, "patch_embd.weight") {
|
||||
shape := t.Shape()
|
||||
if len(shape) == 5 && shape[2] == 2 {
|
||||
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
|
||||
|
||||
t0 := t.Clone()
|
||||
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t0,
|
||||
})
|
||||
|
||||
t1 := t.Clone()
|
||||
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t1,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if len(shape) == 4 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
|
||||
// Fall through to default handling
|
||||
}
|
||||
|
||||
// Handle pre-split patch embedding weights
|
||||
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
|
||||
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
|
||||
if strings.Contains(name, "patch_embd.0.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.Contains(name, "patch_embd.1.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle .weight.0 and .weight.1 suffix patterns
|
||||
if strings.HasSuffix(name, "patch_embd.weight.0") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "patch_embd.weight.1") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
|
||||
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
|
||||
// We permute at conversion time so the weights work correctly with GGML's kernel
|
||||
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
|
||||
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
|
||||
// Get config values for permutation
|
||||
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
|
||||
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
|
||||
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
|
||||
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
|
||||
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
|
||||
|
||||
// Use appropriate head count: nHeads for Q, nKVHeads for K
|
||||
effectiveHeads := nHeads
|
||||
if strings.Contains(name, "attn_k.") {
|
||||
effectiveHeads = nKVHeads
|
||||
}
|
||||
|
||||
permutedT := t.Clone()
|
||||
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: permutedT,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) Replacements() []string {
|
||||
return []string{
|
||||
// Vision encoder
|
||||
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
|
||||
"model.visual.patch_embed.proj", "v.patch_embd",
|
||||
"model.visual.blocks", "v.blk",
|
||||
"model.visual.post_layernorm", "v.post_ln",
|
||||
"model.visual.downsample", "mm.patch_merger",
|
||||
|
||||
// Vision attention
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"attn.q_norm", "attn_q_norm",
|
||||
"attn.k_norm", "attn_k_norm",
|
||||
|
||||
// Vision norms
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
|
||||
// Vision MLP
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
|
||||
// Merger (multimodal projector)
|
||||
"model.visual.merger.proj", "mm.model.fc",
|
||||
"model.visual.merger.post_projection_norm", "mm.post_norm",
|
||||
"model.visual.merger.gate_proj", "mm.gate",
|
||||
"model.visual.merger.up_proj", "mm.up",
|
||||
"model.visual.merger.down_proj", "mm.down",
|
||||
|
||||
// Language model
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"lm_head", "output",
|
||||
|
||||
// Language model attention
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
|
||||
// Language model norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"post_self_attn_layernorm", "post_attn_norm",
|
||||
"post_mlp_layernorm", "post_ffn_norm",
|
||||
|
||||
// Language model MLP (remove mlp. prefix so ffn_* names work)
|
||||
"mlp.gate_up_proj", "ffn_gate_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
}
|
||||
}
|
||||
512
convert/convert_qwen3next.go
Normal file
512
convert/convert_qwen3next.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3NextModel struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
// MoE config
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||
|
||||
// Hybrid attention config
|
||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||
|
||||
// Linear attention (Gated Delta Net) config
|
||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||
LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
|
||||
LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
|
||||
LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
|
||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||
|
||||
// RoPE config
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||
|
||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
||||
if q.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||
}
|
||||
if q.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_attention_heads must be set")
|
||||
}
|
||||
if q.NumKeyValueHeads == 0 {
|
||||
return fmt.Errorf("qwen3next: num_key_value_heads must be set")
|
||||
}
|
||||
if q.HeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: head_dim must be set")
|
||||
}
|
||||
if q.RopeTheta == 0 {
|
||||
return fmt.Errorf("qwen3next: rope_theta must be set")
|
||||
}
|
||||
if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
|
||||
return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
|
||||
}
|
||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||
}
|
||||
if q.FullAttentionInterval == 0 {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||
}
|
||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
hasFull := false
|
||||
for i := range q.NumHiddenLayers {
|
||||
if (i+1)%q.FullAttentionInterval == 0 {
|
||||
hasFull = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFull {
|
||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen3next"
|
||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
||||
kv["block_count"] = q.NumHiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
headDim := q.HeadDim
|
||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||
}
|
||||
kv["attention.key_length"] = headDim
|
||||
kv["attention.value_length"] = headDim
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
|
||||
// RoPE dimension count (partial rotary)
|
||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
||||
partialRotary := q.PartialRotaryFactor
|
||||
if partialRotary > 0 && partialRotary <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||
}
|
||||
|
||||
// MoE config
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
if q.MoEIntermediateSize > 0 {
|
||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||
}
|
||||
if q.SharedExpertIntermSize > 0 {
|
||||
kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
|
||||
}
|
||||
}
|
||||
|
||||
// SSM/Linear attention config
|
||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||
kv["ssm.inner_size"] = dInner
|
||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||
interval := q.FullAttentionInterval
|
||||
kv["full_attention_interval"] = interval
|
||||
|
||||
// Build per-layer KV head count array to identify layer types
|
||||
// 0 = recurrent (linear attention), non-zero = full attention
|
||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
||||
for i := range q.NumHiddenLayers {
|
||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
||||
if interval > 0 && (i+1)%interval == 0 {
|
||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
||||
}
|
||||
// else stays 0 (recurrent layer)
|
||||
}
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
|
||||
// RoPE scaling
|
||||
if q.RopeScaling.Type != "" {
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
||||
merges := make([]merge, q.NumHiddenLayers*3)
|
||||
for i := range q.NumHiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
// Merge expert tensors
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
|
||||
// Process remaining tensors
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := t.Shape()
|
||||
|
||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||
out = append(out, qkv, gate)
|
||||
continue
|
||||
}
|
||||
panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
|
||||
}
|
||||
|
||||
switch {
|
||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
||||
// This matches the Python converter behavior for qwen3next
|
||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
t.SetRepacker(q.addOne)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
||||
// Note: name has already been transformed by Replacements at this point
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
// Compute -exp(A_log)
|
||||
result := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
// -exp(v)
|
||||
result[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
// [1, D, K] -> [D, K]
|
||||
newShape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
// [D, 1, K] -> [D, K]
|
||||
newShape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
||||
newShape := slices.Clone(shape)
|
||||
if len(shape) == 2 {
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
newShape = []uint64{shape[1]}
|
||||
} else if shape[1] == 1 && shape[0] > 1 {
|
||||
newShape = []uint64{shape[0]}
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type qkvzSplitSpec struct {
|
||||
hidden int
|
||||
headKDim int
|
||||
headVDim int
|
||||
numKHeads int
|
||||
numVHeads int
|
||||
qkvzDim int
|
||||
qkvOut int
|
||||
gateOut int
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
|
||||
if len(shape) != 2 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
numKHeads := int(q.LinearNumKeyHeads)
|
||||
numVHeads := int(q.LinearNumValueHeads)
|
||||
headKDim := int(q.LinearKeyHeadDim)
|
||||
headVDim := int(q.LinearValueHeadDim)
|
||||
if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
if numVHeads%numKHeads != 0 {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
hidden := int(shape[1])
|
||||
vPerHead := headVDim * (numVHeads / numKHeads)
|
||||
qkvzDim := 2*headKDim + 2*vPerHead
|
||||
expectedOut := qkvzDim * numKHeads
|
||||
if int(shape[0]) != expectedOut {
|
||||
return qkvzSplitSpec{}, false
|
||||
}
|
||||
|
||||
return qkvzSplitSpec{
|
||||
hidden: hidden,
|
||||
headKDim: headKDim,
|
||||
headVDim: headVDim,
|
||||
numKHeads: numKHeads,
|
||||
numVHeads: numVHeads,
|
||||
qkvzDim: qkvzDim,
|
||||
qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
|
||||
gateOut: headVDim * numVHeads,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
|
||||
spec, ok := q.qkvzSpec(t.Shape())
|
||||
if !ok {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qkvTensor := t.Clone()
|
||||
qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
|
||||
|
||||
gateTensor := t.Clone()
|
||||
gateTensor.SetRepacker(q.repackQKVZ(spec, true))
|
||||
|
||||
qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
|
||||
gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
|
||||
|
||||
return &ggml.Tensor{
|
||||
Name: qkvName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
|
||||
WriterTo: qkvTensor,
|
||||
}, &ggml.Tensor{
|
||||
Name: gateName,
|
||||
Kind: t.Kind(),
|
||||
Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
|
||||
WriterTo: gateTensor,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
|
||||
vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
|
||||
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
var err error
|
||||
|
||||
// Convert to [hidden, out_features] layout for slicing
|
||||
tt, err = tensor.Transpose(tt, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := 0
|
||||
qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += spec.headKDim
|
||||
vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += vPerHead
|
||||
zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qMat := tensor.Materialize(qSlice).(*tensor.Dense)
|
||||
kMat := tensor.Materialize(kSlice).(*tensor.Dense)
|
||||
vMat := tensor.Materialize(vSlice).(*tensor.Dense)
|
||||
zMat := tensor.Materialize(zSlice).(*tensor.Dense)
|
||||
|
||||
if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out tensor.Tensor
|
||||
if extractGate {
|
||||
out = zMat
|
||||
} else {
|
||||
out, err = tensor.Concat(1, qMat, kMat, vMat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
out, err = tensor.Transpose(out, 1, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(out.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||
|
||||
n, err := n.Add(ones)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
func (q *qwen3NextModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embeddings and output
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
|
||||
// Layer norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
|
||||
// Full attention (self_attn)
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
|
||||
// Linear attention (Gated Delta Net)
|
||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||
"linear_attn.in_proj_ba", "ssm_ba",
|
||||
"linear_attn.conv1d", "ssm_conv1d",
|
||||
"linear_attn.dt_bias", "ssm_dt",
|
||||
"linear_attn.dt_proj", "ssm_dt",
|
||||
"linear_attn.A_log", "ssm_a",
|
||||
"linear_attn.norm", "ssm_norm",
|
||||
"linear_attn.out_proj", "ssm_out",
|
||||
|
||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||
|
||||
// Dense FFN (if any layers use it)
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
|
||||
@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
|
||||
if st.dtype == "BF16" &&
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
!strings.HasPrefix(st.name, "mm.") &&
|
||||
!strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -5,7 +5,10 @@ title: Context length
|
||||
Context length is the maximum number of tokens that the model has access to in memory.
|
||||
|
||||
<Note>
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
Ollama defaults to the following context lengths based on VRAM:
|
||||
- < 24 GiB VRAM: 4k context
|
||||
- 24-48 GiB VRAM: 32k context
|
||||
- >= 48 GiB VRAM: 256k context
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
|
||||
|
||||
@@ -71,6 +71,10 @@
|
||||
{
|
||||
"source": "/api",
|
||||
"destination": "/api/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/integrations/clawdbot",
|
||||
"destination": "/integrations/openclaw"
|
||||
}
|
||||
],
|
||||
"navigation": {
|
||||
@@ -101,20 +105,52 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/cline",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/onyx",
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
"/integrations/index",
|
||||
{
|
||||
"group": "Coding",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Assistants",
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "IDEs & Editors",
|
||||
"pages": [
|
||||
"/integrations/cline",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Chat & RAG",
|
||||
"pages": [
|
||||
"/integrations/onyx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Automation",
|
||||
"pages": [
|
||||
"/integrations/n8n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Notebooks",
|
||||
"pages": [
|
||||
"/integrations/marimo"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
22
docs/faq.mdx
22
docs/faq.mdx
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting.mdx) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
Please refer to the [GPU docs](./gpu.mdx).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
@@ -66,7 +66,7 @@ llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
</Info>
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
The `Processor` column will show which memory the model was loaded into:
|
||||
|
||||
- `100% GPU` means the model was loaded entirely into the GPU
|
||||
- `100% CPU` means the model was loaded entirely in system memory
|
||||
@@ -158,7 +158,7 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
|
||||
|
||||
## Does Ollama send my prompts and answers back to ollama.com?
|
||||
|
||||
No. Ollama runs locally, and conversation data does not leave your machine.
|
||||
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
@@ -183,7 +183,7 @@ server {
|
||||
|
||||
## How can I use Ollama with ngrok?
|
||||
|
||||
Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok:
|
||||
Ollama can be accessed using a range of tunneling apps. For example with Ngrok:
|
||||
|
||||
```shell
|
||||
ngrok http 11434 --host-header="localhost:11434"
|
||||
@@ -240,7 +240,7 @@ GPU acceleration is not available for Docker Desktop in macOS due to the lack of
|
||||
|
||||
This can impact both installing Ollama, as well as downloading models.
|
||||
|
||||
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
|
||||
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernet (WSL)` adapter, right click and select `Properties`.
|
||||
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. _Disable_ both of these
|
||||
properties.
|
||||
|
||||
@@ -299,7 +299,7 @@ The `keep_alive` API parameter with the `/api/generate` and `/api/chat` API endp
|
||||
|
||||
## How do I manage the maximum number of requests the Ollama server can queue?
|
||||
|
||||
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
|
||||
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
|
||||
|
||||
## How does Ollama handle concurrent requests?
|
||||
|
||||
@@ -312,10 +312,10 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPU's VRAM.
|
||||
|
||||
## How does Ollama load models on multiple GPUs?
|
||||
|
||||
@@ -382,7 +382,7 @@ ollama signin
|
||||
Replace <username> with your actual Windows user name.
|
||||
</Note>
|
||||
|
||||
## How can I stop Ollama from starting when I login to my computer
|
||||
## How can I stop Ollama from starting when I login to my computer?
|
||||
|
||||
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
|
||||
|
||||
@@ -390,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
|
||||
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
|
||||
|
||||
**MacOS**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under `Allow in the Background`, then click the slider to disable.
|
||||
|
||||
@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
|
||||
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|
||||
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
|
||||
@@ -134,22 +134,12 @@ success
|
||||
|
||||
### Supported Quantizations
|
||||
|
||||
- `q4_0`
|
||||
- `q4_1`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q8_0`
|
||||
|
||||
#### K-means Quantizations
|
||||
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
|
||||
## Sharing your model on ollama.com
|
||||
|
||||
|
||||
50
docs/integrations/index.mdx
Normal file
50
docs/integrations/index.mdx
Normal file
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Overview
|
||||
---
|
||||
|
||||
Ollama integrates with a wide range of tools.
|
||||
|
||||
## Coding Agents
|
||||
|
||||
Coding assistants that can read, modify, and execute code in your projects.
|
||||
|
||||
- [Claude Code](/integrations/claude-code)
|
||||
- [Codex](/integrations/codex)
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
|
||||
## Assistants
|
||||
|
||||
AI assistants that help with everyday tasks.
|
||||
|
||||
- [OpenClaw](/integrations/openclaw)
|
||||
|
||||
## IDEs & Editors
|
||||
|
||||
Native integrations for popular development environments.
|
||||
|
||||
- [VS Code](/integrations/vscode)
|
||||
- [Cline](/integrations/cline)
|
||||
- [Roo Code](/integrations/roo-code)
|
||||
- [JetBrains](/integrations/jetbrains)
|
||||
- [Xcode](/integrations/xcode)
|
||||
- [Zed](/integrations/zed)
|
||||
|
||||
## Chat & RAG
|
||||
|
||||
Chat interfaces and retrieval-augmented generation platforms.
|
||||
|
||||
- [Onyx](/integrations/onyx)
|
||||
|
||||
## Automation
|
||||
|
||||
Workflow automation platforms with AI integration.
|
||||
|
||||
- [n8n](/integrations/n8n)
|
||||
|
||||
## Notebooks
|
||||
|
||||
Interactive computing environments with AI capabilities.
|
||||
|
||||
- [marimo](/integrations/marimo)
|
||||
50
docs/integrations/openclaw.mdx
Normal file
50
docs/integrations/openclaw.mdx
Normal file
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: OpenClaw
|
||||
---
|
||||
|
||||
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Install
|
||||
|
||||
Install [OpenClaw](https://openclaw.ai/)
|
||||
|
||||
```bash
|
||||
npm install -g openclaw@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
openclaw onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||
|
||||
This configures OpenClaw to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch openclaw --config
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
@@ -9,7 +9,7 @@ OpenCode is an open-source AI coding assistant that runs in your terminal.
|
||||
Install the [OpenCode CLI](https://opencode.ai):
|
||||
|
||||
```bash
|
||||
curl -fsSL https://opencode.ai/install.sh | bash
|
||||
curl -fsSL https://opencode.ai/install | bash
|
||||
```
|
||||
|
||||
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
title: Quickstart
|
||||
---
|
||||
|
||||
This quickstart will walk your through running your first model with Ollama. To get started, download Ollama on macOS, Windows or Linux.
|
||||
Ollama is available on macOS, Windows, and Linux.
|
||||
|
||||
<a
|
||||
href="https://ollama.com/download"
|
||||
@@ -12,131 +12,48 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
Download Ollama
|
||||
</a>
|
||||
|
||||
## Run a model
|
||||
## Get Started
|
||||
|
||||
<Tabs>
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "gemma3",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello there!"
|
||||
}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```python
|
||||
from ollama import chat
|
||||
from ollama import ChatResponse
|
||||
|
||||
response: ChatResponse = chat(model='gemma3', messages=[
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
},
|
||||
])
|
||||
print(response['message']['content'])
|
||||
# or access fields directly from the response object
|
||||
print(response.message.content)
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="JavaScript">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install the Ollama JavaScript library:
|
||||
```
|
||||
npm i ollama
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```shell
|
||||
import ollama from 'ollama'
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'gemma3',
|
||||
messages: [{ role: 'user', content: 'Why is the sky blue?' }],
|
||||
})
|
||||
console.log(response.message.content)
|
||||
```
|
||||
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
Run `ollama` in your terminal to open the interactive menu:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
ollama
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
Navigate with `↑/↓`, press `enter` to launch, `→` to change model, and `esc` to quit.
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
The menu provides quick access to:
|
||||
- **Run a model** - Start an interactive chat
|
||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||
- **Additional integrations** - Available under "More..."
|
||||
|
||||
### Launch with a specific model
|
||||
## Coding
|
||||
|
||||
Launch coding tools with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
```sh
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
See [integrations](/integrations) for all supported tools.
|
||||
|
||||
## API
|
||||
|
||||
Use the [API](/api) to integrate Ollama into your applications:
|
||||
|
||||
```sh
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "gemma3",
|
||||
"messages": [{ "role": "user", "content": "Hello!" }]
|
||||
}'
|
||||
```
|
||||
|
||||
See the [API documentation](/api) for Python, JavaScript, and other integrations.
|
||||
|
||||
@@ -201,7 +201,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 4096,
|
||||
"": 0,
|
||||
"2048": 2048,
|
||||
}
|
||||
|
||||
|
||||
@@ -268,8 +268,10 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"lfm2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
@@ -859,11 +861,13 @@ func (f GGML) FlashAttention() bool {
|
||||
"bert",
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
24
go.mod
24
go.mod
@@ -13,7 +13,7 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
@@ -21,13 +21,18 @@ require (
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
@@ -37,22 +42,35 @@ require (
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-pointer v0.0.1 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
|
||||
70
go.sum
70
go.sum
@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
|
||||
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
|
||||
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
|
||||
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
@@ -148,13 +164,19 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
||||
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -162,6 +184,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
|
||||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
@@ -174,14 +202,17 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
@@ -204,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
|
||||
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -218,6 +276,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -304,6 +364,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
|
||||
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||
func TestNumPredict(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
||||
t.Fatalf("failed to pull model: %v", err)
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "Write a long story.",
|
||||
Stream: &stream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"num_predict": 10,
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
logprobCount := 0
|
||||
var finalResponse api.GenerateResponse
|
||||
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
|
||||
logprobCount += len(resp.Logprobs)
|
||||
if resp.Done {
|
||||
finalResponse = resp
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("generate failed: %v", err)
|
||||
}
|
||||
|
||||
if logprobCount != 10 {
|
||||
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
|
||||
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,3 +75,10 @@ type Cache interface {
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
|
||||
// CheckpointCache optionally supports restoring recurrent state to a prior
|
||||
// position to avoid full prompt reprocessing when a prefix mismatch occurs.
|
||||
// The returned position is the number of tokens that can be kept (prefix length).
|
||||
type CheckpointCache interface {
|
||||
PrepareRestore(seq int, targetPos int32) (int32, bool)
|
||||
}
|
||||
|
||||
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
276
llama/patches/0033-ggml-metal-solve_tri.patch
Normal file
@@ -0,0 +1,276 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jeffrey Morgan <jmorganca@gmail.com>
|
||||
Date: Tue, 3 Feb 2026 12:00:00 -0800
|
||||
Subject: [PATCH] ggml: metal solve_tri
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
|
||||
ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
|
||||
7 files changed, 177 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 680904d13..83385c9ef 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
+ assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
+
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
+
|
||||
+ char base[256];
|
||||
+ char name[256];
|
||||
+
|
||||
+ snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
+ snprintf(name, 256, "%s", base);
|
||||
+
|
||||
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
+ if (!res.pipeline) {
|
||||
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
+ }
|
||||
+
|
||||
+ return res;
|
||||
+}
|
||||
+
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index 0a8b9211a..8a9d17460 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index 7b5ee968c..4e5acfbe5 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ return ggml_is_contiguous(op->src[0]) &&
|
||||
+ ggml_is_contiguous(op->src[1]) &&
|
||||
+ op->src[0]->type == GGML_TYPE_F32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_F32 &&
|
||||
+ op->type == GGML_TYPE_F32;
|
||||
+ case GGML_OP_COUNT_EQUAL:
|
||||
+ return has_simdgroup_reduction &&
|
||||
+ op->src[0]->type == GGML_TYPE_I32 &&
|
||||
+ op->src[1]->type == GGML_TYPE_I32 &&
|
||||
+ op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
index 8944b07e9..cfdea9c07 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
+typedef struct {
|
||||
+ int32_t ne00;
|
||||
+ int32_t ne01;
|
||||
+ int32_t ne02;
|
||||
+ int32_t ne03;
|
||||
+ uint64_t nb00;
|
||||
+ uint64_t nb01;
|
||||
+ uint64_t nb02;
|
||||
+ uint64_t nb03;
|
||||
+ int32_t ne10;
|
||||
+ int32_t ne11;
|
||||
+ uint64_t nb10;
|
||||
+ uint64_t nb11;
|
||||
+ uint64_t nb12;
|
||||
+ uint64_t nb13;
|
||||
+ uint64_t nb0;
|
||||
+ uint64_t nb1;
|
||||
+ uint64_t nb2;
|
||||
+ uint64_t nb3;
|
||||
+} ggml_metal_kargs_solve_tri;
|
||||
+
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 80864f303..4ac135603 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
+ case GGML_OP_SOLVE_TRI:
|
||||
+ {
|
||||
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
+ } break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
+ ggml_tensor * op = ctx->node(idx);
|
||||
+
|
||||
+ ggml_metal_library_t lib = ctx->lib;
|
||||
+ ggml_metal_encoder_t enc = ctx->enc;
|
||||
+
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
+
|
||||
+ ggml_metal_kargs_solve_tri args = {
|
||||
+ /*.ne00 =*/ ne00,
|
||||
+ /*.ne01 =*/ ne01,
|
||||
+ /*.ne02 =*/ ne02,
|
||||
+ /*.ne03 =*/ ne03,
|
||||
+ /*.nb00 =*/ nb00,
|
||||
+ /*.nb01 =*/ nb01,
|
||||
+ /*.nb02 =*/ nb02,
|
||||
+ /*.nb03 =*/ nb03,
|
||||
+ /*.ne10 =*/ ne10,
|
||||
+ /*.ne11 =*/ ne11,
|
||||
+ /*.nb10 =*/ nb10,
|
||||
+ /*.nb11 =*/ nb11,
|
||||
+ /*.nb12 =*/ nb12,
|
||||
+ /*.nb13 =*/ nb13,
|
||||
+ /*.nb0 =*/ nb0,
|
||||
+ /*.nb1 =*/ nb1,
|
||||
+ /*.nb2 =*/ nb2,
|
||||
+ /*.nb3 =*/ nb3,
|
||||
+ };
|
||||
+
|
||||
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
+
|
||||
+ const int64_t ncols = ne10;
|
||||
+ const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
+ const int64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ int nth = 64;
|
||||
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
+ if (nth < 1) {
|
||||
+ nth = 1;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
+
|
||||
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
+
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
+
|
||||
+ return 1;
|
||||
+}
|
||||
+
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
index 902b54452..a475183d3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d33c16079..c37447a10 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
+kernel void kernel_solve_tri_f32(
|
||||
+ constant ggml_metal_kargs_solve_tri & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device char * dst,
|
||||
+ uint tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort ntg[[threads_per_threadgroup]]) {
|
||||
+ const uint64_t ncols = (uint64_t) args.ne10;
|
||||
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
+ const uint64_t nr = n_batches * ncols;
|
||||
+
|
||||
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
+ if (gid >= nr) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
+ const uint64_t i02 = rem / ncols;
|
||||
+ const uint64_t i01 = rem - i02 * ncols;
|
||||
+
|
||||
+ const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
+ const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
+ const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
+ const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
+ const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
+ const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
+ const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
+
|
||||
+ const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
+ const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
+ const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
+ const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
+
|
||||
+ device const float * A = (device const float *) src0;
|
||||
+ device const float * B = (device const float *) src1;
|
||||
+ device float * X = (device float *) dst;
|
||||
+
|
||||
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
+
|
||||
+ const uint64_t n = (uint64_t) args.ne11;
|
||||
+
|
||||
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
+ float sum = 0.0f;
|
||||
+ for (uint64_t t = 0; t < i00; ++t) {
|
||||
+ sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
+ X[X_base + t * sx1 + i01 * sx0];
|
||||
+ }
|
||||
+
|
||||
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
+ X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -80,6 +81,7 @@ type LlamaServer interface {
|
||||
GetPort() int
|
||||
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
|
||||
HasExited() bool
|
||||
ContextLength() int
|
||||
}
|
||||
|
||||
// llmServer is an instance of a runner hosting a single model
|
||||
@@ -115,7 +117,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -141,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tok tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -154,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -210,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -260,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tok != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -309,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1200,7 +1202,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do load request: %w", err)
|
||||
slog.Error("do load request", "error", err)
|
||||
return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -1772,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1807,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1901,6 +1904,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *llmServer) ContextLength() int {
|
||||
return s.options.NumCtx
|
||||
}
|
||||
|
||||
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, s)
|
||||
if err != nil {
|
||||
|
||||
@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
@@ -170,10 +170,12 @@ type Tensor interface {
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
GELU_ERF(ctx Context) Tensor
|
||||
QuickGELU(ctx Context, up ...Tensor) Tensor
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
SigmoidOut(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
@@ -206,6 +208,32 @@ type Tensor interface {
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Exp(ctx Context) Tensor
|
||||
Neg(ctx Context) Tensor
|
||||
|
||||
// Clamp clamps values to [min, max] range
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
|
||||
// Softplus computes ln(1 + exp(x))
|
||||
Softplus(ctx Context) Tensor
|
||||
|
||||
// CumSum computes cumulative sum along dimension 0
|
||||
CumSum(ctx Context) Tensor
|
||||
|
||||
// Diag creates a diagonal matrix from a 1D tensor
|
||||
Diag(ctx Context) Tensor
|
||||
|
||||
// Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
|
||||
Tri(ctx Context, triType int) Tensor
|
||||
|
||||
// Fill fills a tensor with a constant value (in-place)
|
||||
Fill(ctx Context, value float32) Tensor
|
||||
|
||||
// Repeat4D repeats tensor to match target shape
|
||||
Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
|
||||
|
||||
// SolveTri solves a triangular system Ax = B
|
||||
SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
|
||||
|
||||
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*8)
|
||||
maxGraphNodes := max(1024, len(meta.Tensors().Items())*32)
|
||||
|
||||
sched := C.ggml_backend_sched_new_ext(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
|
||||
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
@@ -1581,6 +1588,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
var tt *C.struct_ggml_tensor
|
||||
if len(t2) > 0 {
|
||||
@@ -1772,6 +1786,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_exp(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_diag(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
|
||||
var mode C.uint32_t
|
||||
switch samplingMode {
|
||||
|
||||
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SOLVE_TRI);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_solve_tri_f32");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_GROUP_NORM);
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
op->src[1]->type == GGML_TYPE_I32 &&
|
||||
op->type == GGML_TYPE_I64;
|
||||
case GGML_OP_ARGMAX:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_NORM:
|
||||
|
||||
@@ -2385,6 +2385,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -500,6 +500,27 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_solve_tri;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_solve_tri args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
||||
|
||||
const int64_t ncols = ne10;
|
||||
const int64_t n_batches = (int64_t)ne02 * ne03;
|
||||
const int64_t nr = n_batches * ncols;
|
||||
|
||||
int nth = 64;
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
if (nth < 1) {
|
||||
nth = 1;
|
||||
}
|
||||
|
||||
const int64_t n_tg = (nr + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const uint64_t ncols = (uint64_t) args.ne10;
|
||||
const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
|
||||
const uint64_t nr = n_batches * ncols;
|
||||
|
||||
const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
|
||||
if (gid >= nr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
|
||||
const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
|
||||
const uint64_t i02 = rem / ncols;
|
||||
const uint64_t i01 = rem - i02 * ncols;
|
||||
|
||||
const uint64_t sa0 = args.nb00 / sizeof(float);
|
||||
const uint64_t sa1 = args.nb01 / sizeof(float);
|
||||
const uint64_t sa2 = args.nb02 / sizeof(float);
|
||||
const uint64_t sa3 = args.nb03 / sizeof(float);
|
||||
|
||||
const uint64_t sb0 = args.nb10 / sizeof(float);
|
||||
const uint64_t sb1 = args.nb11 / sizeof(float);
|
||||
const uint64_t sb2 = args.nb12 / sizeof(float);
|
||||
const uint64_t sb3 = args.nb13 / sizeof(float);
|
||||
|
||||
const uint64_t sx0 = args.nb0 / sizeof(float);
|
||||
const uint64_t sx1 = args.nb1 / sizeof(float);
|
||||
const uint64_t sx2 = args.nb2 / sizeof(float);
|
||||
const uint64_t sx3 = args.nb3 / sizeof(float);
|
||||
|
||||
device const float * A = (device const float *) src0;
|
||||
device const float * B = (device const float *) src1;
|
||||
device float * X = (device float *) dst;
|
||||
|
||||
const uint64_t A_base = i02 * sa2 + i03 * sa3;
|
||||
const uint64_t B_base = i02 * sb2 + i03 * sb3;
|
||||
const uint64_t X_base = i02 * sx2 + i03 * sx3;
|
||||
|
||||
const uint64_t n = (uint64_t) args.ne11;
|
||||
|
||||
for (uint64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (uint64_t t = 0; t < i00; ++t) {
|
||||
sum += A[A_base + i00 * sa1 + t * sa0] *
|
||||
X[X_base + t * sx1 + i01 * sx0];
|
||||
}
|
||||
|
||||
const float diag = A[A_base + i00 * sa1 + i00 * sa0];
|
||||
X[X_base + i00 * sx1 + i01 * sx0] =
|
||||
(B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -56,6 +56,18 @@ type fakeTensor struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// Stub methods to satisfy ml.Tensor interface
|
||||
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
return &fakeTensor{Name: name}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
174
model/models/glmocr/imageprocessor.go
Normal file
174
model/models/glmocr/imageprocessor.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
spatialMergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
imageMean [3]float32
|
||||
imageStd [3]float32
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
|
||||
// Read normalization values from config if available, otherwise use CLIP defaults
|
||||
imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
|
||||
imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
|
||||
|
||||
// Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
|
||||
// This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
|
||||
defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
|
||||
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 336)),
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
|
||||
maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
|
||||
factor: patchSize * spatialMergeSize,
|
||||
imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
|
||||
imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
temporalFactor := p.temporalPatchSize
|
||||
numFrames := temporalFactor // single image
|
||||
|
||||
if height < factor || width < factor {
|
||||
// Scale up small images
|
||||
scale := float64(factor) / float64(min(height, width))
|
||||
height = int(math.Ceil(float64(height) * scale))
|
||||
width = int(math.Ceil(float64(width) * scale))
|
||||
}
|
||||
|
||||
if temporalFactor <= 0 {
|
||||
slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
|
||||
temporalFactor = 1
|
||||
}
|
||||
if numFrames < temporalFactor {
|
||||
slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
|
||||
numFrames = temporalFactor
|
||||
}
|
||||
if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
|
||||
slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
|
||||
}
|
||||
|
||||
round := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
|
||||
|
||||
if tBar*hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if tBar*hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
|
||||
|
||||
// Normalize pixels - output format is [C, H, W] with rescale and channelFirst
|
||||
// We keep [C, H, W] for patch extraction
|
||||
normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
|
||||
|
||||
// Calculate grid dimensions (after Conv2D patching)
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // Single image
|
||||
ImageHeight: resizedHeight,
|
||||
ImageWidth: resizedWidth,
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return patches, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := 3
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.spatialMergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
patchIndex := 0
|
||||
|
||||
// Single temporal frame handling (copies to all frames)
|
||||
for range grid.Temporal {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
srcIdx := c*height*width + y*width + x
|
||||
dstIdx := channelOffset + (py * patchSize) + px
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
|
||||
if temporalPatchSize > 1 {
|
||||
frameSize := patchSize * patchSize
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||
result[channelOffset:channelOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
236
model/models/glmocr/model.go
Normal file
236
model/models/glmocr/model.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
|
||||
PatchMerger *PatchMerger `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageStartTokenID int32
|
||||
imageEndTokenID int32
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
|
||||
eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: allEOS,
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: int32(c.Uint("image_token_id", 59280)),
|
||||
imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
|
||||
imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Blocks) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create pixel values tensor from flattened patches
|
||||
// Shape: [patchDim, numPatches]
|
||||
patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||
|
||||
// Forward through vision encoder
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||
|
||||
// Forward through downsample (patch merger)
|
||||
if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
|
||||
return nil, errors.New("glmocr: missing vision downsample weights")
|
||||
}
|
||||
visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
|
||||
|
||||
// Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
|
||||
if m.PatchMerger == nil {
|
||||
return nil, errors.New("glmocr: missing patch merger weights")
|
||||
}
|
||||
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
// Reset position cache
|
||||
m.TextModel.positionCache = m.TextModel.positionCache[:0]
|
||||
m.TextModel.ropeDelta = 0
|
||||
|
||||
pos := int32(0)
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Get grid info for position calculation
|
||||
grid := inp.Multimodal[0].Data.(*Grid)
|
||||
mergedH := grid.Height / m.VisionModel.spatialMergeSize
|
||||
mergedW := grid.Width / m.VisionModel.spatialMergeSize
|
||||
|
||||
// Add image start token
|
||||
result = append(result, &input.Input{Token: m.imageStartTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
|
||||
// Add image tokens with multimodal data
|
||||
// All image tokens share the same base position for temporal dimension
|
||||
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||
basePos := pos
|
||||
sameBatch := tokensPerGrid - 1
|
||||
if sameBatch < 0 {
|
||||
sameBatch = 0
|
||||
}
|
||||
result = append(result, &input.Input{
|
||||
Token: m.imageTokenID,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: sameBatch,
|
||||
})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
|
||||
// Add placeholder tokens for remaining positions
|
||||
// All image tokens use the same base position (temporal stays constant)
|
||||
for range tokensPerGrid - 1 {
|
||||
result = append(result, &input.Input{Token: m.imageTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
|
||||
}
|
||||
|
||||
// Advance position by max(mergedH, mergedW) after image tokens
|
||||
pos = basePos + int32(max(mergedH, mergedW))
|
||||
|
||||
// Add image end token
|
||||
result = append(result, &input.Input{Token: m.imageEndTokenID})
|
||||
m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
|
||||
pos++
|
||||
}
|
||||
|
||||
// Compute rope delta for continuation after the prefill segment:
|
||||
// delta = (max_position_id + 1) - sequence_length
|
||||
if len(m.TextModel.positionCache) > 0 {
|
||||
last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
|
||||
m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||
ctx.Forward(hiddenStates)
|
||||
|
||||
// Build position slices for M-RoPE
|
||||
positionSlice := func() [][]int32 {
|
||||
s := [][]int32{
|
||||
make([]int32, len(batch.Positions)), // temporal
|
||||
make([]int32, len(batch.Positions)), // height
|
||||
make([]int32, len(batch.Positions)), // width
|
||||
make([]int32, len(batch.Positions)), // unused (zeros)
|
||||
}
|
||||
for i, position := range batch.Positions {
|
||||
// Translate through position cache or continue sequence
|
||||
if position < int32(len(m.TextModel.positionCache)) {
|
||||
position = m.TextModel.positionCache[position]
|
||||
} else if len(m.TextModel.positionCache) > 0 {
|
||||
// Continue sequence after cached positions using ropeDelta
|
||||
position = position + m.TextModel.ropeDelta
|
||||
}
|
||||
|
||||
s[0][i] = position
|
||||
s[1][i] = position
|
||||
s[2][i] = position
|
||||
}
|
||||
return s
|
||||
}()
|
||||
|
||||
// Inject vision embeddings and adjust positions for image tokens
|
||||
for _, mi := range batch.Multimodal {
|
||||
img := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||
|
||||
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||
w := grid.Width / m.VisionModel.spatialMergeSize
|
||||
for i := range img.Dim(1) {
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.TextModel.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("glmocr", New)
|
||||
}
|
||||
190
model/models/glmocr/model_text.go
Normal file
190
model/models/glmocr/model_text.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type TextModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
numKVHeads int
|
||||
headDim int
|
||||
rotaryDim int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
ropeBase float32
|
||||
mropeSections []int
|
||||
}
|
||||
|
||||
func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
// With 4 sections for [temporal, height, width, unused]
|
||||
return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Separate Q, K, V projections
|
||||
q := sa.Query.Forward(ctx, hiddenStates)
|
||||
k := sa.Key.Forward(ctx, hiddenStates)
|
||||
v := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
// Reshape for GQA
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
// Apply M-RoPE (multi-resolution rotary position embeddings)
|
||||
q = opts.applyMRoPE(ctx, q, positions)
|
||||
k = opts.applyMRoPE(ctx, k, positions)
|
||||
|
||||
// Scaled dot-product attention with KV cache
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
// Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
|
||||
// Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
|
||||
kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type TextDecoderLayer struct {
|
||||
// Input layernorm (before attention)
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
// Post self-attention layernorm (after attention, before residual add)
|
||||
PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
|
||||
|
||||
// FFN input layernorm (after first residual, before MLP)
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
// Post MLP layernorm (after MLP, before residual add)
|
||||
PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
|
||||
}
|
||||
|
||||
func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
|
||||
// Attention block
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Prune to output positions in final layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
// MLP block
|
||||
residual = hiddenStates
|
||||
hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []TextDecoderLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextModelOptions
|
||||
|
||||
// positionCache stores the M-RoPE position for each token in the sequence.
|
||||
// This is needed because image tokens share the same base position but have
|
||||
// different height/width offsets, and the end token position depends on the
|
||||
// image grid dimensions.
|
||||
positionCache []int32
|
||||
ropeDelta int32
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// Clear position cache when KV cache shifts
|
||||
m.positionCache = nil
|
||||
m.ropeDelta = 0
|
||||
return m.applyMRoPE(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
hiddenSize := int(c.Uint("embedding_length", 1536))
|
||||
numHeads := int(c.Uint("attention.head_count", 16))
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv", 8))
|
||||
intermediateSize := int(c.Uint("feed_forward_length", 4608))
|
||||
eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
|
||||
ropeBase := c.Float("rope.freq_base", 10000)
|
||||
|
||||
headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
|
||||
ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
|
||||
if ropeDim <= 0 {
|
||||
ropeDim = headDim
|
||||
}
|
||||
|
||||
mropeSections := c.Ints("rope.mrope_section")
|
||||
var sectionInts []int
|
||||
|
||||
if len(mropeSections) > 0 {
|
||||
sectionInts = make([]int, len(mropeSections))
|
||||
for i, section := range mropeSections {
|
||||
sectionInts[i] = int(section)
|
||||
}
|
||||
} else {
|
||||
// Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
|
||||
// For rotaryDim=64 this yields [8, 12, 12].
|
||||
total := ropeDim / 2
|
||||
if total <= 0 {
|
||||
total = 32
|
||||
}
|
||||
s0 := total * 2 / 8
|
||||
s1 := total * 3 / 8
|
||||
s2 := total - s0 - s1
|
||||
sectionInts = []int{s0, s1, s2}
|
||||
}
|
||||
|
||||
// GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
|
||||
rotaryDim := ropeDim
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
|
||||
TextModelOptions: &TextModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
headDim: headDim,
|
||||
rotaryDim: rotaryDim,
|
||||
intermediateSize: intermediateSize,
|
||||
eps: eps,
|
||||
ropeBase: ropeBase,
|
||||
mropeSections: sectionInts,
|
||||
},
|
||||
}
|
||||
}
|
||||
355
model/models/glmocr/model_vision.go
Normal file
355
model/models/glmocr/model_vision.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package glmocr
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
type Grid struct {
|
||||
Height int // Number of patches in height direction
|
||||
Width int // Number of patches in width direction
|
||||
Temporal int
|
||||
ImageHeight int // Full image height in pixels
|
||||
ImageWidth int // Full image width in pixels
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
imageSize int
|
||||
spatialMergeSize int
|
||||
outHiddenSize int
|
||||
intermediateSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionPatchEmbed struct {
|
||||
Proj *nn.Conv2D `gguf:"patch_embd_0"`
|
||||
Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
Bias ml.Tensor `gguf:"patch_embd.bias"`
|
||||
}
|
||||
|
||||
func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
|
||||
_ = grid // patches are already in merge-block order
|
||||
|
||||
// pixelValues shape: [patchDim, numPatches]
|
||||
numPatches := pixelValues.Shape()[1]
|
||||
|
||||
// Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
|
||||
// Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Slice temporal frames for Conv2D (simulate Conv3D)
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
|
||||
s0, s1 := opts.patchSize, opts.patchSize
|
||||
p0, p1 := 0, 0
|
||||
d0, d1 := 1, 1
|
||||
hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||
hiddenStates = hiddenStates.Add(ctx, out1)
|
||||
}
|
||||
|
||||
// Flatten to [hidden_size, num_patches]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
|
||||
|
||||
// Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
|
||||
if pe.Bias != nil {
|
||||
hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
QKV *nn.Linear `gguf:"attn_qkv"`
|
||||
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
// Combined QKV projection: [3*hidden_size, batch_size]
|
||||
qkv := sa.QKV.Forward(ctx, hiddenStates)
|
||||
|
||||
// Split using ChunkSections along dim 0 (handles byte offsets correctly)
|
||||
// ChunkSections returns views - must make contiguous before further operations
|
||||
chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
|
||||
q := chunks[0].Contiguous(ctx)
|
||||
k := chunks[1].Contiguous(ctx)
|
||||
v := chunks[2].Contiguous(ctx)
|
||||
|
||||
// Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
|
||||
|
||||
// Apply Q-norm and K-norm after head reshape
|
||||
// Weights are [headDim]=64, tensor is [headDim, numHeads, N]
|
||||
q = sa.QNorm.Forward(ctx, q, opts.eps)
|
||||
k = sa.KNorm.Forward(ctx, k, opts.eps)
|
||||
|
||||
// Apply rotary position embeddings with vision-style 2D positions.
|
||||
// ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
|
||||
// We provide H/W sections and leave the remaining sections empty.
|
||||
ropeFreqBase := float32(10000.0)
|
||||
section := opts.headDim / 4
|
||||
if section <= 0 {
|
||||
section = 1
|
||||
}
|
||||
sections := []int{section, section, 0, 0}
|
||||
q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
// Try flash attention first (ScaledDotProductAttention), fall back to manual
|
||||
if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
|
||||
attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: vision attention falling back to manual attention",
|
||||
"batchSize", batchSize, "numHeads", opts.numHeads,
|
||||
"hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
|
||||
|
||||
// Manual attention fallback
|
||||
// q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
// Attention scores
|
||||
kq := k.MulmatFullPrec(ctx, q)
|
||||
kq = kq.Scale(ctx, scale)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
// Attention output: v @ kq (note: v first)
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
// SwiGLU: down(silu(gate(x)) * up(x))
|
||||
gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type VisionBlock struct {
|
||||
Norm1 *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
Norm2 *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Pre-norm architecture
|
||||
residual := hiddenStates
|
||||
hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = b.MLP.Forward(ctx, hiddenStates)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type VisionDownsample struct {
|
||||
*nn.Conv2D
|
||||
}
|
||||
|
||||
func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
|
||||
// Apply spatial downsampling via Conv2D
|
||||
// Input: [hidden_size, num_patches] where patches are in merge-block order
|
||||
|
||||
if d.Conv2D == nil || d.Weight == nil {
|
||||
slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
|
||||
return hiddenStates // Return input unchanged as fallback
|
||||
}
|
||||
|
||||
merge := opts.spatialMergeSize
|
||||
numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
|
||||
|
||||
// Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
|
||||
|
||||
// Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
|
||||
// ggml semantics: result.ne[perm[i]] = input.ne[i]
|
||||
// So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
|
||||
// Step 3: Apply Conv2D without bias (bias added after reshape)
|
||||
// Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
|
||||
s0, s1 := merge, merge
|
||||
p0, p1 := 0, 0
|
||||
d0, d1 := 1, 1
|
||||
hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
// Step 4: Reshape to [out_hidden_size, num_output_tokens]
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
|
||||
|
||||
// Step 5: Add bias after reshape
|
||||
// Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
|
||||
if d.Bias != nil {
|
||||
hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type PatchMerger struct {
|
||||
// GGUF tags align with mm.* keys used by the model
|
||||
Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
|
||||
PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
|
||||
GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
|
||||
UpProj *nn.Linear `gguf:"up"` // mm.up.weight
|
||||
DownProj *nn.Linear `gguf:"down"` // mm.down.weight
|
||||
}
|
||||
|
||||
func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Linear projection
|
||||
hiddenStates = m.Proj.Forward(ctx, hiddenStates)
|
||||
|
||||
// Post-projection layer norm + GELU ERF
|
||||
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = hiddenStates.GELU_ERF(ctx)
|
||||
// Force a copy to avoid in-place mutation issues with GELU_ERF
|
||||
hiddenStates = hiddenStates.Contiguous(ctx)
|
||||
|
||||
// SwiGLU MLP: down(silu(gate(x)) * up(x))
|
||||
gateOut := m.GateProj.Forward(ctx, hiddenStates)
|
||||
upOut := m.UpProj.Forward(ctx, hiddenStates)
|
||||
gate := gateOut.SILU(ctx, upOut)
|
||||
return m.DownProj.Forward(ctx, gate)
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbed *VisionPatchEmbed
|
||||
Blocks []VisionBlock `gguf:"blk"`
|
||||
PostLN *nn.RMSNorm `gguf:"post_ln"`
|
||||
// Note: Downsample is applied at the model level so mm.patch_merger stays separate
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||
// Extract patch embeddings from flattened patches
|
||||
hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
|
||||
|
||||
// Create position IDs for RoPE (spatial grid)
|
||||
// Patches are already in merge-block order from preprocessing
|
||||
positions := m.createPositions(ctx, grid)
|
||||
|
||||
// Process through vision blocks
|
||||
for _, block := range m.Blocks {
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
// Post-layernorm
|
||||
hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
// Note: Downsample is now applied separately in Model.EncodeMultimodal
|
||||
// so mm.patch_merger remains a distinct module
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
// Create spatial position IDs for vision RoPE
|
||||
// Position layout: [height, width, height, width] - 4 sections for mrope
|
||||
// Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
|
||||
// This follows the GLM-OCR rot_pos_emb layout
|
||||
numPatches := grid.Height * grid.Width
|
||||
mergeRatio := m.spatialMergeSize
|
||||
|
||||
// Build position arrays in merge-block order
|
||||
// Each merge_ratio x merge_ratio block of patches is grouped together
|
||||
hpos := make([]int32, numPatches)
|
||||
wpos := make([]int32, numPatches)
|
||||
ptr := 0
|
||||
for y := 0; y < grid.Height; y += mergeRatio {
|
||||
for x := 0; x < grid.Width; x += mergeRatio {
|
||||
for dy := range mergeRatio {
|
||||
for dx := range mergeRatio {
|
||||
hpos[ptr] = int32(y + dy)
|
||||
wpos[ptr] = int32(x + dx)
|
||||
ptr++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
|
||||
// keep remaining sections zeroed to match its conventions.
|
||||
zeros := make([]int32, numPatches)
|
||||
s := [][]int32{
|
||||
hpos, // Section 0: height
|
||||
wpos, // Section 1: width
|
||||
zeros, // Section 2: unused
|
||||
zeros, // Section 3: unused
|
||||
}
|
||||
|
||||
return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
hiddenSize := int(c.Uint("vision.embedding_length", 1024))
|
||||
numHeads := int(c.Uint("vision.attention.head_count", 16))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
imageSize := int(c.Uint("vision.image_size", 336))
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
|
||||
intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
|
||||
eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
|
||||
|
||||
return &VisionModel{
|
||||
Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: hiddenSize / numHeads,
|
||||
numChannels: numChannels,
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
imageSize: imageSize,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
outHiddenSize: outHiddenSize,
|
||||
intermediateSize: intermediateSize,
|
||||
eps: eps,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
_ "github.com/ollama/ollama/model/models/lfm2"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
@@ -19,5 +20,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3next"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,28 +45,24 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user