mirror of
https://github.com/ollama/ollama.git
synced 2026-02-12 08:33:46 -05:00
Compare commits
28 Commits
drifkin/de
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ecb89007e | ||
|
|
da0190f8e4 | ||
|
|
1f2594f50f | ||
|
|
3fb0958e01 | ||
|
|
f08427c138 | ||
|
|
2dbb000908 | ||
|
|
c980e19995 | ||
|
|
6162374ca9 | ||
|
|
44bdd9a2ef | ||
|
|
db493d6e5e | ||
|
|
75695f16a5 | ||
|
|
a0407d07fa | ||
|
|
9ec733e527 | ||
|
|
5ef04dab52 | ||
|
|
aea316f1e9 | ||
|
|
235ba3df5c | ||
|
|
099a0f18ef | ||
|
|
fff696ee31 | ||
|
|
2e3ce6eab3 | ||
|
|
9e2003f88a | ||
|
|
42e1d49fbe | ||
|
|
814630ca60 | ||
|
|
87cf187774 | ||
|
|
6ddd8862cd | ||
|
|
f1373193dc | ||
|
|
8a4b77f9da | ||
|
|
5f53fe7884 | ||
|
|
7ab4ca0e7f |
6
.github/workflows/release.yaml
vendored
6
.github/workflows/release.yaml
vendored
@@ -337,6 +337,7 @@ jobs:
|
||||
name: bundles-windows
|
||||
path: |
|
||||
dist/*.zip
|
||||
dist/*.ps1
|
||||
dist/OllamaSetup.exe
|
||||
|
||||
linux-build:
|
||||
@@ -514,6 +515,9 @@ jobs:
|
||||
- name: Log dist contents
|
||||
run: |
|
||||
ls -l dist/
|
||||
- name: Copy install scripts to dist
|
||||
run: |
|
||||
cp scripts/install.sh dist/install.sh
|
||||
- name: Generate checksum file
|
||||
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
@@ -536,7 +540,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
22
.github/workflows/test-install.yaml
vendored
Normal file
22
.github/workflows/test-install.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: test-install
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'scripts/install.sh'
|
||||
- '.github/workflows/test-install.yaml'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Run install script
|
||||
run: sh ./scripts/install.sh
|
||||
env:
|
||||
OLLAMA_NO_START: 1 # do not start app
|
||||
- name: Verify ollama is available
|
||||
run: ollama --version
|
||||
@@ -182,7 +182,7 @@ option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
@@ -216,4 +216,4 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -147,7 +147,7 @@ ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
|
||||
@@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -779,3 +785,117 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||
func EstimateInputTokens(req MessagesRequest) int {
|
||||
return estimateTokens(CountTokensRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
System: req.System,
|
||||
Tools: req.Tools,
|
||||
Thinking: req.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokensResponse represents an Anthropic count_tokens response
|
||||
type CountTokensResponse struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||
func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||
}
|
||||
|
||||
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||
tokens := totalLen / 4
|
||||
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||
tokens = 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func countAnyContent(content any) int {
|
||||
if content == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -678,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -731,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -778,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -842,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -899,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -937,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -969,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||
if tokens < 1 {
|
||||
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||
}
|
||||
// Sanity check: shouldn't be wildly off
|
||||
if tokens > 10 {
|
||||
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// System prompt adds to count
|
||||
if tokens < 5 {
|
||||
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather for a location",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Tools add significant content
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
// Thinking content should be counted
|
||||
if tokens < 10 {
|
||||
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{},
|
||||
}
|
||||
|
||||
tokens := estimateTokens(req)
|
||||
|
||||
if tokens != 0 {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
13
cmd/background_unix.go
Normal file
13
cmd/background_unix.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
|
||||
// Setpgid prevents the server from being killed when the parent process exits.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
12
cmd/background_windows.go
Normal file
12
cmd/background_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
|
||||
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: 0x08000000,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
264
cmd/cmd.go
264
cmd/cmd.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -37,6 +38,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -53,6 +55,49 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return userName, err
|
||||
}
|
||||
|
||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
ok, err := tui.RunConfirm(prompt)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return false, config.ErrCancelled
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
}
|
||||
|
||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
@@ -1763,7 +1808,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1804,6 +1849,197 @@ Environment Variables:
|
||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
|
||||
}
|
||||
|
||||
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
|
||||
func ensureServerRunning(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if server is already running
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server is already running
|
||||
}
|
||||
|
||||
// Server not running, start it in the background
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not find executable: %w", err)
|
||||
}
|
||||
|
||||
serverCmd := exec.CommandContext(ctx, exe, "serve")
|
||||
serverCmd.Env = os.Environ()
|
||||
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
|
||||
if err := serverCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the server to be ready
|
||||
for {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
if err := ensureServerRunning(cmd.Context()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectSingle(title, tuiItems)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := make([]tui.SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
}
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
if result.Model != "" {
|
||||
// Model already selected from modal - save and launch
|
||||
if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCLI() *cobra.Command {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
cobra.EnableCommandSorting = false
|
||||
@@ -1826,11 +2062,13 @@ func NewCLI() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Print(cmd.UsageString())
|
||||
runInteractiveTUI(cmd)
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
@@ -1934,6 +2172,15 @@ func NewCLI() *cobra.Command {
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
loginCmd := &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "Sign in to ollama.com",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
signoutCmd := &cobra.Command{
|
||||
Use: "signout",
|
||||
Short: "Sign out from ollama.com",
|
||||
@@ -1942,6 +2189,15 @@ func NewCLI() *cobra.Command {
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
logoutCmd := &cobra.Command{
|
||||
Use: "logout",
|
||||
Short: "Sign out from ollama.com",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
@@ -2038,13 +2294,15 @@ func NewCLI() *cobra.Command {
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
loginCmd,
|
||||
signoutCmd,
|
||||
logoutCmd,
|
||||
listCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat),
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
@@ -53,10 +58,135 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -103,3 +104,95 @@ func TestClaudeArgs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
saveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,21 +3,26 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
@@ -53,7 +58,6 @@ func migrateConfig() (bool, error) {
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -72,7 +76,6 @@ func migrateConfig() (bool, error) {
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
slog.Info("migrated config", "from", oldPath, "to", newPath)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -133,13 +136,89 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return ic.Models[0]
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
func LastModel() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastModel
|
||||
}
|
||||
|
||||
// SetLastModel saves the last model that was run.
|
||||
func SetLastModel(model string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastModel = model
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
|
||||
func LastSelection() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastSelection
|
||||
}
|
||||
|
||||
// SetLastSelection saves the last menu selection ("run" or integration name).
|
||||
func SetLastSelection(selection string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastSelection = selection
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
@@ -154,6 +233,29 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
677
cmd/config/config_cloud_test.go
Normal file
677
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -112,9 +114,17 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
@@ -122,7 +132,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
|
||||
@@ -1251,6 +1251,55 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
if err := d.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
|
||||
models := settings["customModels"].([]any)
|
||||
entry := models[0].(map[string]any)
|
||||
if entry["maxOutputTokens"] != float64(64000) {
|
||||
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -39,6 +39,15 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
|
||||
// Integrations like Claude and Codex use this to route model requests to local models.
|
||||
type AliasConfigurer interface {
|
||||
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
|
||||
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
|
||||
// SetAliases syncs the configured aliases to the server
|
||||
SetAliases(ctx context.Context, aliases map[string]string) error
|
||||
}
|
||||
|
||||
// integrations is the registry of available integrations.
|
||||
var integrations = map[string]Runner{
|
||||
"claude": &Claude{},
|
||||
@@ -48,30 +57,268 @@ var integrations = map[string]Runner{
|
||||
"droid": &Droid{},
|
||||
"opencode": &OpenCode{},
|
||||
"openclaw": &Openclaw{},
|
||||
"pi": &Pi{},
|
||||
}
|
||||
|
||||
// recommendedModels are shown when the user has no models or as suggestions.
|
||||
// Order matters: local models first, then cloud models.
|
||||
var recommendedModels = []selectItem{
|
||||
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
|
||||
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
|
||||
{Name: "glm-4.7:cloud", Description: "Recommended"},
|
||||
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
|
||||
var recommendedModels = []ModelItem{
|
||||
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
|
||||
{Name: "glm-4.7:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
|
||||
{Name: "qwen3:8b", Description: "Efficient all-purpose assistant", Recommended: true},
|
||||
}
|
||||
|
||||
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
||||
var recommendedVRAM = map[string]string{
|
||||
"glm-4.7-flash": "~25GB",
|
||||
"qwen3:8b": "~11GB",
|
||||
}
|
||||
|
||||
// integrationAliases are hidden from the interactive selector but work as CLI arguments.
|
||||
var integrationAliases = map[string]bool{
|
||||
"clawdbot": true,
|
||||
"moltbot": true,
|
||||
"pi": true,
|
||||
}
|
||||
|
||||
// integrationInstallHints maps integration names to install URLs.
|
||||
var integrationInstallHints = map[string]string{
|
||||
"claude": "https://code.claude.com/docs/en/quickstart",
|
||||
"openclaw": "https://docs.openclaw.ai",
|
||||
"codex": "https://developers.openai.com/codex/cli/",
|
||||
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
"opencode": "https://opencode.ai",
|
||||
}
|
||||
|
||||
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
||||
func hyperlink(url, text string) string {
|
||||
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||
}
|
||||
|
||||
// IntegrationInfo contains display information about a registered integration.
|
||||
type IntegrationInfo struct {
|
||||
Name string // registry key, e.g. "claude"
|
||||
DisplayName string // human-readable, e.g. "Claude Code"
|
||||
Description string // short description, e.g. "Anthropic's agentic coding tool"
|
||||
}
|
||||
|
||||
// integrationDescriptions maps integration names to short descriptions.
|
||||
var integrationDescriptions = map[string]string{
|
||||
"claude": "Anthropic's coding tool with subagents",
|
||||
"codex": "OpenAI's open-source coding agent",
|
||||
"openclaw": "Personal AI with 100+ skills",
|
||||
"droid": "Factory's coding agent across terminal and IDEs",
|
||||
"opencode": "Anomaly's open-source coding agent",
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
var result []IntegrationInfo
|
||||
for name, r := range integrations {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
result = append(result, IntegrationInfo{
|
||||
Name: name,
|
||||
DisplayName: r.String(),
|
||||
Description: integrationDescriptions[name],
|
||||
})
|
||||
}
|
||||
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// IntegrationInstallHint returns a user-friendly install hint for the given integration,
|
||||
// or an empty string if none is available. The URL is wrapped in an OSC 8 hyperlink
|
||||
// so it is cmd+clickable in supported terminals.
|
||||
func IntegrationInstallHint(name string) string {
|
||||
url := integrationInstallHints[name]
|
||||
if url == "" {
|
||||
return ""
|
||||
}
|
||||
return "Install from " + hyperlink(url, url)
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
switch name {
|
||||
case "claude":
|
||||
c := &Claude{}
|
||||
_, err := c.findPath()
|
||||
return err == nil
|
||||
case "openclaw":
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
case "codex":
|
||||
_, err := exec.LookPath("codex")
|
||||
return err == nil
|
||||
case "droid":
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
case "opencode":
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
default:
|
||||
return true // Assume installed for unknown integrations
|
||||
}
|
||||
}
|
||||
|
||||
// SelectModel lets the user select a model to run.
|
||||
// ModelItem represents a model for selection.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
|
||||
// SelectModelWithSelector prompts the user to select a model using the provided selector.
|
||||
func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
preChecked = []string{lastModel}
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
selected, err := selector("Select model to run:", items)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If the selected model isn't installed, pull it first
|
||||
if !existingModels[selected] {
|
||||
msg := fmt.Sprintf("Download %s?", selected)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return "", err
|
||||
} else if !ok {
|
||||
return "", errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, selected); err != nil {
|
||||
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If it's a cloud model, ensure user is signed in
|
||||
if cloudModels[selected] {
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
|
||||
if err != nil || !yes {
|
||||
return "", fmt.Errorf("%s requires sign in", selected)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return "", ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func SelectModel(ctx context.Context) (string, error) {
|
||||
return SelectModelWithSelector(ctx, DefaultSingleSelector)
|
||||
}
|
||||
|
||||
// DefaultSingleSelector is the default single-select implementation.
|
||||
var DefaultSingleSelector SingleSelector
|
||||
|
||||
// DefaultMultiSelector is the default multi-select implementation.
|
||||
var DefaultMultiSelector MultiSelector
|
||||
|
||||
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||
// When set, ensureAuth uses it instead of plain text prompts.
|
||||
// Returns the signed-in username or an error.
|
||||
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
func selectIntegration() (string, error) {
|
||||
if DefaultSingleSelector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
if len(integrations) == 0 {
|
||||
return "", fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
var items []ModelItem
|
||||
for _, name := range names {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
@@ -81,14 +328,14 @@ func selectIntegration() (string, error) {
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, selectItem{Name: name, Description: description})
|
||||
items = append(items, ModelItem{Name: name, Description: description})
|
||||
}
|
||||
|
||||
return selectPrompt("Select integration:", items)
|
||||
return DefaultSingleSelector("Select integration:", items)
|
||||
}
|
||||
|
||||
// selectModels lets the user select models for an integration
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||
func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
@@ -124,12 +371,16 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
|
||||
var selected []string
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := single(prompt, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -157,73 +408,157 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): pull this out to tui package
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{
|
||||
Name: m.Name,
|
||||
Remote: m.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
}
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
|
||||
if DefaultSignIn != nil {
|
||||
_, err := DefaultSignIn(modelList, aErr.SigninURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
// Fallback: plain text sign-in flow
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
OpenBrowser(aErr.SigninURL)
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
// selectModels lets the user select models for an integration using default selectors.
|
||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return selectModelsWithSelectors(ctx, name, current, DefaultSingleSelector, DefaultMultiSelector)
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
@@ -231,19 +566,151 @@ func runIntegration(name, modelName string, args []string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
|
||||
func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if err := ac.SetAliases(ctx, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
return saveAliases(name, aliases)
|
||||
}
|
||||
|
||||
// LaunchIntegration launches the named integration using saved config or prompts for setup.
|
||||
func LaunchIntegration(name string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// Try to use saved config
|
||||
if ic, err := loadIntegration(name); err == nil && len(ic.Models) > 0 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ShowOrPull(context.Background(), client, ic.Models[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(name, ic.Models[0], nil)
|
||||
}
|
||||
|
||||
// No saved config - prompt user to run setup
|
||||
return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
|
||||
}
|
||||
|
||||
// LaunchIntegrationWithModel launches the named integration with the specified model.
|
||||
func LaunchIntegrationWithModel(name, modelName string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ShowOrPull(context.Background(), client, modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(name, modelName, nil)
|
||||
}
|
||||
|
||||
// SaveIntegrationModel saves the model for an integration.
|
||||
func SaveIntegrationModel(name, modelName string) error {
|
||||
// Load existing models and prepend the new one
|
||||
var models []string
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
models = existing.Models
|
||||
// Remove the model if it already exists
|
||||
for i, m := range models {
|
||||
if m == modelName {
|
||||
models = append(models[:i], models[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Prepend the new model
|
||||
models = append([]string{modelName}, models...)
|
||||
return saveIntegration(name, models)
|
||||
}
|
||||
|
||||
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
|
||||
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return errCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if editor, isEditor := r.(Editor); isEditor {
|
||||
paths := editor.Paths()
|
||||
if len(paths) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||
for _, p := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||
|
||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := saveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
if len(models) == 1 {
|
||||
fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigureIntegration allows the user to select/change the model for an integration.
|
||||
func ConfigureIntegration(ctx context.Context, name string) error {
|
||||
return ConfigureIntegrationWithSelectors(ctx, name, DefaultSingleSelector, DefaultMultiSelector)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||
// The runTUI callback is called when no arguments are provided (alias for main TUI).
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
Short: "Launch the Ollama menu or an integration",
|
||||
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||
|
||||
Without arguments, this is equivalent to running 'ollama' directly.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
@@ -262,6 +729,12 @@ Examples:
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// No args and no flags - show the full TUI (same as bare 'ollama')
|
||||
if len(args) == 0 && modelFlag == "" && !configFlag {
|
||||
runTUI(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
var passArgs []string
|
||||
@@ -302,9 +775,102 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0], passArgs)
|
||||
// Handle AliasConfigurer integrations (claude, codex)
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate --model flag if provided
|
||||
if modelFlag != "" {
|
||||
if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var model string
|
||||
var existingAliases map[string]string
|
||||
|
||||
// Load saved config
|
||||
if cfg, err := loadIntegration(name); err == nil {
|
||||
existingAliases = cfg.Aliases
|
||||
if len(cfg.Models) > 0 {
|
||||
model = cfg.Models[0]
|
||||
// AliasConfigurer integrations use single model; sanitize if multiple
|
||||
if len(cfg.Models) > 1 {
|
||||
_ = saveIntegration(name, []string{model})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --model flag overrides saved model
|
||||
if modelFlag != "" {
|
||||
model = modelFlag
|
||||
}
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
||||
model = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid model or --config flag, show picker
|
||||
if model == "" || configFlag {
|
||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model = aliases["primary"]
|
||||
existingAliases = aliases
|
||||
}
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Sync aliases and save
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
if err := saveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
// Launch (unless --config without confirmation)
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
|
||||
// Validate --model flag for non-AliasConfigurer integrations
|
||||
if modelFlag != "" {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,6 +884,8 @@ Examples:
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
||||
return runIntegration(name, saved.Models[0], passArgs)
|
||||
} else {
|
||||
var err error
|
||||
models, err = selectModels(cmd.Context(), name, "")
|
||||
@@ -380,20 +948,23 @@ Examples:
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
@@ -406,10 +977,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
@@ -418,7 +986,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModel(rec.Name) {
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
@@ -446,31 +1014,76 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
items[i].Description += ", install?"
|
||||
} else {
|
||||
items[i].Description = "install?"
|
||||
parts = append(parts, items[i].Description)
|
||||
}
|
||||
if vram := recommendedVRAM[items[i].Name]; vram != "" {
|
||||
parts = append(parts, vram)
|
||||
}
|
||||
parts = append(parts, "install?")
|
||||
items[i].Description = strings.Join(parts, ", ")
|
||||
}
|
||||
}
|
||||
|
||||
// Build a recommended rank map to preserve ordering within tiers.
|
||||
recRank := make(map[string]int)
|
||||
for i, rec := range recommendedModels {
|
||||
recRank[rec.Name] = i + 1 // 1-indexed; 0 means not recommended
|
||||
}
|
||||
|
||||
onlyLocal := hasLocalModel && !hasCloudModel
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
|
||||
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
|
||||
|
||||
// Checked/pre-selected always first
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if !ac && !bc && aNew != bNew {
|
||||
|
||||
// Recommended above non-recommended
|
||||
if aRec != bRec {
|
||||
if aRec {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Both recommended
|
||||
if aRec && bRec {
|
||||
if aCloud != bCloud {
|
||||
if onlyLocal {
|
||||
// Local before cloud when only local installed
|
||||
if aCloud {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
// Cloud before local in mixed case
|
||||
if aCloud {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return recRank[a.Name] - recRank[b.Name]
|
||||
}
|
||||
|
||||
// Both non-recommended: installed before not-installed
|
||||
if aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
@@ -478,8 +1091,45 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
// isCloudModel checks if a model is a cloud model using the Show API.
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
// GetModelItems returns a list of model items including recommendations for the TUI.
|
||||
// It includes all locally available models plus recommended models that aren't installed.
|
||||
func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
lastModel := LastModel()
|
||||
var preChecked []string
|
||||
if lastModel != "" {
|
||||
preChecked = []string{lastModel}
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
|
||||
|
||||
return items, existingModels
|
||||
}
|
||||
|
||||
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -88,8 +94,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
@@ -122,6 +128,75 @@ func TestLaunchCmd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmd_TUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
// Will error because claude isn't configured, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model flag provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config flag provided")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
@@ -162,7 +237,7 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil)
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
@@ -297,27 +372,18 @@ func TestParseArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsCloudModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isCloudModel(tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
// isCloudModel now only uses Show API, so nil client always returns false
|
||||
t.Run("nil client returns false", func(t *testing.T) {
|
||||
models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"}
|
||||
for _, model := range models {
|
||||
if isCloudModel(context.Background(), nil, model) {
|
||||
t.Errorf("isCloudModel(%q) with nil client should return false", model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
func names(items []ModelItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
@@ -328,7 +394,7 @@ func names(items []selectItem) []string {
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
|
||||
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -349,9 +415,10 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "kimi-k2.5:cloud", "glm-4.7:cloud", "llama3.2", "qwen2.5"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
|
||||
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,9 +431,10 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
// All recs pinned at top (cloud before local in mixed case), then non-recs
|
||||
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
|
||||
t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -417,9 +485,10 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
|
||||
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
|
||||
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
// All recs: cloud first in mixed case, then local, in rec order within each
|
||||
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
|
||||
t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,9 +503,10 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
|
||||
// All recs pinned at top (cloud first in mixed case), then non-recs
|
||||
want := []string{"kimi-k2.5:cloud", "glm-4.7:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
|
||||
t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
@@ -509,3 +579,588 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_RecommendedFieldSet(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
case "glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud":
|
||||
if !item.Recommended {
|
||||
t.Errorf("%q should have Recommended=true", item.Name)
|
||||
}
|
||||
case "llama3.2":
|
||||
if item.Recommended {
|
||||
t.Errorf("%q should have Recommended=false", item.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Cloud recs should sort before local recs in mixed case
|
||||
cloudIdx := slices.Index(got, "glm-4.7:cloud")
|
||||
localIdx := slices.Index(got, "glm-4.7-flash")
|
||||
if cloudIdx > localIdx {
|
||||
t.Errorf("cloud recs should be before local recs in mixed case, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// Local recs should sort before cloud recs in only-local case
|
||||
localIdx := slices.Index(got, "glm-4.7-flash")
|
||||
cloudIdx := slices.Index(got, "glm-4.7:cloud")
|
||||
if localIdx > cloudIdx {
|
||||
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "custom-model", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// All recommended models should appear before non-recommended installed models
|
||||
lastRecIdx := -1
|
||||
firstNonRecIdx := len(got)
|
||||
for i, name := range got {
|
||||
isRec := name == "glm-4.7-flash" || name == "qwen3:8b" || name == "glm-4.7:cloud" || name == "kimi-k2.5:cloud"
|
||||
if isRec && i > lastRecIdx {
|
||||
lastRecIdx = i
|
||||
}
|
||||
if !isRec && i < firstNonRecIdx {
|
||||
firstNonRecIdx = i
|
||||
}
|
||||
}
|
||||
if lastRecIdx > firstNonRecIdx {
|
||||
t.Errorf("all recs should be above non-recs, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_CheckedBeforeRecs(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
t.Errorf("checked model should be first even before recs, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save a config for opencode so it looks like a previous launch
|
||||
if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(saved.Models) == 0 {
|
||||
t.Fatal("expected saved models")
|
||||
}
|
||||
if saved.Models[0] != "llama3.2" {
|
||||
t.Errorf("expected llama3.2, got %s", saved.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"test-model"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "test-model")
|
||||
if err != nil {
|
||||
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
if err == nil {
|
||||
t.Error("showOrPull should return error when model not found and no terminal available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
|
||||
var receivedModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
|
||||
receivedModel = req.Model
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
_ = ShowOrPull(context.Background(), client, "qwen3:8b")
|
||||
if receivedModel != "qwen3:8b" {
|
||||
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_ConfirmYes_Pulls(t *testing.T) {
|
||||
// Set up hook so confirmPrompt doesn't need a terminal
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
if !strings.Contains(prompt, "missing-model") {
|
||||
t.Errorf("expected prompt to contain model name, got %q", prompt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed after pull, got: %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Error("expected pull to be called when user confirms download")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return false, ErrCancelled
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
t.Error("pull should not be called when user declines")
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
if err == nil {
|
||||
t.Error("ShowOrPull should return error when user declines")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
var hookCalled bool
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
hookCalled = true
|
||||
if prompt != "test prompt?" {
|
||||
t.Errorf("expected prompt %q, got %q", "test prompt?", prompt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
ok, err := confirmPrompt("test prompt?")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Error("expected true from hook")
|
||||
}
|
||||
if !hookCalled {
|
||||
t.Error("expected DefaultConfirmPrompt hook to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_NoCloudModels(t *testing.T) {
|
||||
// ensureAuth should be a no-op when no cloud models are selected
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("no API calls expected when no cloud models selected")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
|
||||
// ensureAuth should only care about models in cloudModels map
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"testuser"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"cloud-model:cloud", "local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
|
||||
}
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// cloudModels has entries but none are in selected
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got: %v", err)
|
||||
}
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called when no cloud models are selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHyperlink(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
text string
|
||||
wantURL string
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "basic link",
|
||||
url: "https://example.com",
|
||||
text: "click here",
|
||||
wantURL: "https://example.com",
|
||||
wantText: "click here",
|
||||
},
|
||||
{
|
||||
name: "url with path",
|
||||
url: "https://example.com/docs/install",
|
||||
text: "install docs",
|
||||
wantURL: "https://example.com/docs/install",
|
||||
wantText: "install docs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hyperlink(tt.url, tt.text)
|
||||
|
||||
// Should contain OSC 8 escape sequences
|
||||
if !strings.Contains(got, "\033]8;;") {
|
||||
t.Error("should contain OSC 8 open sequence")
|
||||
}
|
||||
if !strings.Contains(got, tt.wantURL) {
|
||||
t.Errorf("should contain URL %q", tt.wantURL)
|
||||
}
|
||||
if !strings.Contains(got, tt.wantText) {
|
||||
t.Errorf("should contain text %q", tt.wantText)
|
||||
}
|
||||
|
||||
// Should have closing OSC 8 sequence
|
||||
wantSuffix := "\033]8;;\033\\"
|
||||
if !strings.HasSuffix(got, wantSuffix) {
|
||||
t.Error("should end with OSC 8 close sequence")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationInstallHint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantEmpty bool
|
||||
wantURL string
|
||||
}{
|
||||
{
|
||||
name: "claude has hint",
|
||||
input: "claude",
|
||||
wantURL: "https://code.claude.com/docs/en/quickstart",
|
||||
},
|
||||
{
|
||||
name: "codex has hint",
|
||||
input: "codex",
|
||||
wantURL: "https://developers.openai.com/codex/cli/",
|
||||
},
|
||||
{
|
||||
name: "openclaw has hint",
|
||||
input: "openclaw",
|
||||
wantURL: "https://docs.openclaw.ai",
|
||||
},
|
||||
{
|
||||
name: "unknown has no hint",
|
||||
input: "unknown",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty name has no hint",
|
||||
input: "",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IntegrationInstallHint(tt.input)
|
||||
if tt.wantEmpty {
|
||||
if got != "" {
|
||||
t.Errorf("expected empty hint, got %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !strings.Contains(got, "Install from") {
|
||||
t.Errorf("hint should start with 'Install from', got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, tt.wantURL) {
|
||||
t.Errorf("hint should contain URL %q, got %q", tt.wantURL, got)
|
||||
}
|
||||
// Should be a clickable hyperlink
|
||||
if !strings.Contains(got, "\033]8;;") {
|
||||
t.Error("hint URL should be wrapped in OSC 8 hyperlink")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListIntegrationInfos(t *testing.T) {
|
||||
infos := ListIntegrationInfos()
|
||||
|
||||
t.Run("excludes aliases", func(t *testing.T) {
|
||||
for _, info := range infos {
|
||||
if integrationAliases[info.Name] {
|
||||
t.Errorf("alias %q should not appear in ListIntegrationInfos", info.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sorted by name", func(t *testing.T) {
|
||||
for i := 1; i < len(infos); i++ {
|
||||
if infos[i-1].Name >= infos[i].Name {
|
||||
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all fields populated", func(t *testing.T) {
|
||||
for _, info := range infos {
|
||||
if info.Name == "" {
|
||||
t.Error("Name should not be empty")
|
||||
}
|
||||
if info.DisplayName == "" {
|
||||
t.Errorf("DisplayName for %q should not be empty", info.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("includes known integrations", func(t *testing.T) {
|
||||
known := map[string]bool{"claude": false, "codex": false, "opencode": false}
|
||||
for _, info := range infos {
|
||||
if _, ok := known[info.Name]; ok {
|
||||
known[info.Name] = true
|
||||
}
|
||||
}
|
||||
for name, found := range known {
|
||||
if !found {
|
||||
t.Errorf("expected %q in ListIntegrationInfos", name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
t.Run("installed recommended has base description", func(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "qwen3:8b", Remote: false},
|
||||
}
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
if strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("installed model should not have 'install?' suffix, got %q", item.Description)
|
||||
}
|
||||
if item.Description == "" {
|
||||
t.Error("installed recommended model should have a description")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Error("qwen3:8b not found in items")
|
||||
})
|
||||
|
||||
t.Run("not-installed local rec has VRAM in description", func(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
if !strings.Contains(item.Description, "~11GB") {
|
||||
t.Errorf("not-installed qwen3:8b should show VRAM hint, got %q", item.Description)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Error("qwen3:8b not found in items")
|
||||
})
|
||||
|
||||
t.Run("installed local rec omits VRAM", func(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "qwen3:8b", Remote: false},
|
||||
}
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3:8b" {
|
||||
if strings.Contains(item.Description, "~11GB") {
|
||||
t.Errorf("installed qwen3:8b should not show VRAM hint, got %q", item.Description)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Error("qwen3:8b not found in items")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := LaunchIntegration("nonexistent-integration")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_NotConfigured(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Claude is a known integration but not configured in temp dir
|
||||
err := LaunchIntegration("claude")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when integration is not configured")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not configured") {
|
||||
t.Errorf("error should mention 'not configured', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,6 @@ type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
const ansiGreen = "\033[32m"
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -10,12 +11,53 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -113,6 +155,8 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -122,12 +166,29 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
models[model] = map[string]any{
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -495,6 +496,166 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
237
cmd/config/pi.go
Normal file
237
cmd/config/pi.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||
type Pi struct{}
|
||||
|
||||
func (p *Pi) String() string { return "Pi" }
|
||||
|
||||
func (p *Pi) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (p *Pi) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if _, err := os.Stat(modelsPath); err == nil {
|
||||
paths = append(paths, modelsPath)
|
||||
}
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
if _, err := os.Stat(settingsPath); err == nil {
|
||||
paths = append(paths, settingsPath)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (p *Pi) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
providers, ok := config["providers"].(map[string]any)
|
||||
if !ok {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"baseUrl": envconfig.Host().String() + "/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
}
|
||||
}
|
||||
|
||||
existingModels, ok := ollama["models"].([]any)
|
||||
if !ok {
|
||||
existingModels = make([]any, 0)
|
||||
}
|
||||
|
||||
// Build set of selected models to track which need to be added
|
||||
selectedSet := make(map[string]bool, len(models))
|
||||
for _, m := range models {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
// User-managed model (no _launch marker) - always preserve
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Ollama-managed and still selected - keep it
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add newly selected models that weren't already in the list
|
||||
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||
ctx := context.Background()
|
||||
for _, model := range models {
|
||||
if selectedSet[model] {
|
||||
newModels = append(newModels, createConfig(ctx, client, model))
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = newModels
|
||||
providers["ollama"] = ollama
|
||||
config["providers"] = providers
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update settings.json with default provider and model
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
settings := make(map[string]any)
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
_ = json.Unmarshal(data, &settings)
|
||||
}
|
||||
|
||||
settings["defaultProvider"] = "ollama"
|
||||
settings["defaultModel"] = models[0]
|
||||
|
||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, settingsData)
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := readJSONFile(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers, _ := config["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range models {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
slices.Sort(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
||||
func isPiOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection
|
||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": modelID,
|
||||
"_launch": true,
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
cfg["input"] = []string{"text", "image"}
|
||||
} else {
|
||||
cfg["input"] = []string{"text"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
830
cmd/config/pi_test.go
Normal file
830
cmd/config/pi_test.go
Normal file
@@ -0,0 +1,830 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPiIntegration(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := pi.String(); got != "Pi" {
|
||||
t.Errorf("String() = %q, want %q", got, "Pi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = pi
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = pi
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiPaths(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns empty when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("Paths() = %v, want empty", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiEdit(t *testing.T) {
|
||||
// Mock Ollama server for createConfig calls during Edit
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pi := &Pi{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
}
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
t.Run("returns nil for empty models", func(t *testing.T) {
|
||||
if err := pi.Edit([]string{}); err != nil {
|
||||
t.Errorf("Edit([]) error = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates config with models", func(t *testing.T) {
|
||||
cleanup()
|
||||
|
||||
models := []string{"llama3.2", "qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
|
||||
providers, ok := cfg["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Config missing providers")
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Providers missing ollama")
|
||||
}
|
||||
|
||||
modelsArray, ok := ollama["models"].([]any)
|
||||
if !ok || len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %v", modelsArray)
|
||||
}
|
||||
|
||||
if ollama["baseUrl"] == nil {
|
||||
t.Error("Missing baseUrl")
|
||||
}
|
||||
if ollama["api"] != "openai-completions" {
|
||||
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "ollama" {
|
||||
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://custom:8080/v1",
|
||||
"api": "custom-api",
|
||||
"apiKey": "custom-key",
|
||||
"models": [
|
||||
{"id": "old-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"new-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
|
||||
if ollama["baseUrl"] != "http://custom:8080/v1" {
|
||||
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
|
||||
}
|
||||
if ollama["api"] != "custom-api" {
|
||||
t.Errorf("Custom api not preserved, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "custom-key" {
|
||||
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
|
||||
}
|
||||
|
||||
modelsArray := ollama["models"].([]any)
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
|
||||
} else {
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
if modelEntry["id"] != "new-model" {
|
||||
t.Errorf("Expected new-model, got %v", modelEntry["id"])
|
||||
}
|
||||
// Verify _launch marker is present
|
||||
if modelEntry["_launch"] != true {
|
||||
t.Errorf("Expected _launch marker to be true")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Old models must have _launch marker to be managed by us
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "old-model-1", "_launch": true},
|
||||
{"id": "old-model-2", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"new-model-1", "new-model-2"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
|
||||
t.Errorf("Expected new models, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
|
||||
t.Errorf("Old models should have been removed, got %v", modelIDs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles partial overlap in model list", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Models must have _launch marker to be managed
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "keep-model", "_launch": true},
|
||||
{"id": "remove-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"keep-model", "add-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
|
||||
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["remove-model"] {
|
||||
t.Errorf("remove-model should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model, got %d", len(modelsArray))
|
||||
}
|
||||
})
|
||||
|
||||
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
|
||||
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// User has manually configured models in ollama provider (no _launch marker)
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "user-model-1"},
|
||||
{"id": "user-model-2", "customField": "preserved"},
|
||||
{"id": "ollama-managed", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a new ollama-managed model
|
||||
newModels := []string{"new-ollama-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
// Should have: new-ollama-model (managed) + 2 user models (preserved)
|
||||
if len(modelsArray) != 3 {
|
||||
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]map[string]any)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = modelObj
|
||||
}
|
||||
|
||||
// Verify new model has _launch marker
|
||||
if m, ok := modelIDs["new-ollama-model"]; !ok {
|
||||
t.Errorf("new-ollama-model should be present")
|
||||
} else if m["_launch"] != true {
|
||||
t.Errorf("new-ollama-model should have _launch marker")
|
||||
}
|
||||
|
||||
// Verify user models are preserved
|
||||
if _, ok := modelIDs["user-model-1"]; !ok {
|
||||
t.Errorf("user-model-1 should be preserved")
|
||||
}
|
||||
if _, ok := modelIDs["user-model-2"]; !ok {
|
||||
t.Errorf("user-model-2 should be preserved")
|
||||
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
|
||||
t.Errorf("user-model-2 customField should be preserved")
|
||||
}
|
||||
|
||||
// Verify old ollama-managed model is removed (not in new list)
|
||||
if _, ok := modelIDs["ollama-managed"]; ok {
|
||||
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create existing settings with other fields
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
existingSettings := `{
|
||||
"theme": "dark",
|
||||
"customSetting": "value",
|
||||
"defaultProvider": "anthropic",
|
||||
"defaultModel": "claude-3"
|
||||
}`
|
||||
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"llama3.2"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
// Verify defaultProvider is set to ollama
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
|
||||
// Verify defaultModel is set to first model
|
||||
if settings["defaultModel"] != "llama3.2" {
|
||||
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
|
||||
}
|
||||
|
||||
// Verify other fields are preserved
|
||||
if settings["theme"] != "dark" {
|
||||
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
|
||||
}
|
||||
if settings["customSetting"] != "value" {
|
||||
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
models := []string{"qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("settings.json should be created: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "qwen3:8b" {
|
||||
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create corrupt settings
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "test-model" {
|
||||
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiModels(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns models from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "llama3.2"},
|
||||
{"id": "qwen3:8b"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("Models() returned %d models, want 2", len(models))
|
||||
}
|
||||
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
|
||||
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns sorted models", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "z-model"},
|
||||
{"id": "a-model"},
|
||||
{"id": "m-model"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
|
||||
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when models array is missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil when models array is missing", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil for corrupt config", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsPiOllamaModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg map[string]any
|
||||
want bool
|
||||
}{
|
||||
{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
|
||||
{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
|
||||
{"without _launch", map[string]any{"id": "m"}, false},
|
||||
{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
|
||||
{"empty map", map[string]any{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isPiOllamaModel(tt.cfg); got != tt.want {
|
||||
t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llava:7b")
|
||||
|
||||
if cfg["id"] != "llava:7b" {
|
||||
t.Errorf("id = %v, want llava:7b", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set for non-thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwq")
|
||||
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
if cfg["contextWindow"] != 131072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwen3-vision")
|
||||
|
||||
input := cfg["input"].([]string)
|
||||
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", input)
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true")
|
||||
}
|
||||
if cfg["contextWindow"] != 32768 {
|
||||
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns minimal config when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "missing-model")
|
||||
|
||||
if cfg["id"] != "missing-model" {
|
||||
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
// Should not have capability fields
|
||||
if _, ok := cfg["input"]; ok {
|
||||
t.Error("input should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set when show fails")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "test-model")
|
||||
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set for zero value")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure Capability constants used in createConfig match expected values
|
||||
func TestPiCapabilityConstants(t *testing.T) {
|
||||
if model.CapabilityVision != "vision" {
|
||||
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
|
||||
}
|
||||
if model.CapabilityThinking != "thinking" {
|
||||
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
|
||||
}
|
||||
}
|
||||
@@ -3,474 +3,34 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiHideCursor = "\033[?25l"
|
||||
ansiShowCursor = "\033[?25h"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiClearDown = "\033[J"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
)
|
||||
|
||||
const maxDisplayedItems = 10
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
type selectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
type inputEvent int
|
||||
|
||||
const (
|
||||
eventNone inputEvent = iota
|
||||
eventEnter
|
||||
eventEscape
|
||||
eventUp
|
||||
eventDown
|
||||
eventTab
|
||||
eventBackspace
|
||||
eventChar
|
||||
)
|
||||
|
||||
type selectState struct {
|
||||
items []selectItem
|
||||
filter string
|
||||
selected int
|
||||
scrollOffset int
|
||||
}
|
||||
|
||||
func newSelectState(items []selectItem) *selectState {
|
||||
return &selectState{items: items}
|
||||
}
|
||||
|
||||
func (s *selectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||
return true, filtered[s.selected].Name, nil
|
||||
}
|
||||
case eventEscape:
|
||||
return true, "", errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
case eventUp:
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
if s.selected < s.scrollOffset {
|
||||
s.scrollOffset = s.selected
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.selected < len(filtered)-1 {
|
||||
s.selected++
|
||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.selected = 0
|
||||
s.scrollOffset = 0
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
type multiSelectState struct {
|
||||
items []selectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
highlighted int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
focusOnButton bool
|
||||
}
|
||||
|
||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||
s := &multiSelectState{
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
s.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := s.itemIndex[name]; ok {
|
||||
s.checked[idx] = true
|
||||
s.checkOrder = append(s.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *multiSelectState) filtered() []selectItem {
|
||||
return filterItems(s.items, s.filter)
|
||||
}
|
||||
|
||||
func (s *multiSelectState) toggleItem() {
|
||||
filtered := s.filtered()
|
||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[s.highlighted]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
if s.checked[origIdx] {
|
||||
delete(s.checked, origIdx)
|
||||
for i, idx := range s.checkOrder {
|
||||
if idx == origIdx {
|
||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.checked[origIdx] = true
|
||||
s.checkOrder = append(s.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||
filtered := s.filtered()
|
||||
|
||||
switch event {
|
||||
case eventEnter:
|
||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||
var res []string
|
||||
for _, idx := range s.checkOrder {
|
||||
res = append(res, s.items[idx].Name)
|
||||
}
|
||||
return true, res, nil
|
||||
} else if !s.focusOnButton {
|
||||
s.toggleItem()
|
||||
}
|
||||
case eventTab:
|
||||
if len(s.checkOrder) > 0 {
|
||||
s.focusOnButton = !s.focusOnButton
|
||||
}
|
||||
case eventEscape:
|
||||
return true, nil, errCancelled
|
||||
case eventBackspace:
|
||||
if len(s.filter) > 0 {
|
||||
s.filter = s.filter[:len(s.filter)-1]
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
case eventUp:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted > 0 {
|
||||
s.highlighted--
|
||||
if s.highlighted < s.scrollOffset {
|
||||
s.scrollOffset = s.highlighted
|
||||
}
|
||||
}
|
||||
case eventDown:
|
||||
if s.focusOnButton {
|
||||
s.focusOnButton = false
|
||||
} else if s.highlighted < len(filtered)-1 {
|
||||
s.highlighted++
|
||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||
}
|
||||
}
|
||||
case eventChar:
|
||||
s.filter += string(char)
|
||||
s.highlighted = 0
|
||||
s.scrollOffset = 0
|
||||
s.focusOnButton = false
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (s *multiSelectState) selectedCount() int {
|
||||
return len(s.checkOrder)
|
||||
}
|
||||
|
||||
// Terminal I/O handling
|
||||
|
||||
type terminalState struct {
|
||||
fd int
|
||||
oldState *term.State
|
||||
}
|
||||
|
||||
func enterRawMode() (*terminalState, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||
}
|
||||
|
||||
func (t *terminalState) restore() {
|
||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||
term.Restore(t.fd, t.oldState)
|
||||
}
|
||||
|
||||
func clearLines(n int) {
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case n == 1 && buf[0] == 13:
|
||||
return eventEnter, 0, nil
|
||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||
return eventEscape, 0, nil
|
||||
case n == 1 && buf[0] == 9:
|
||||
return eventTab, 0, nil
|
||||
case n == 1 && buf[0] == 127:
|
||||
return eventBackspace, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||
return eventUp, 0, nil
|
||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||
return eventDown, 0, nil
|
||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||
return eventChar, buf[0], nil
|
||||
}
|
||||
|
||||
return eventNone, 0, nil
|
||||
}
|
||||
|
||||
// Rendering
|
||||
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
prefix := " "
|
||||
if idx == s.selected {
|
||||
prefix = " " + ansiBold + "> "
|
||||
}
|
||||
if item.Description != "" {
|
||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||
lineCount++
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxDisplayedItems)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := s.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := s.itemIndex[item.Name]
|
||||
|
||||
checkbox := "[ ]"
|
||||
if s.checked[origIdx] {
|
||||
checkbox = "[x]"
|
||||
}
|
||||
|
||||
prefix := " "
|
||||
suffix := ""
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
prefix = "> "
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
desc := ""
|
||||
if item.Description != "" {
|
||||
desc = " " + ansiGray + "- " + item.Description + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||
lineCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\r\n")
|
||||
lineCount++
|
||||
count := s.selectedCount()
|
||||
switch {
|
||||
case count == 0:
|
||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||
case s.focusOnButton:
|
||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||
default:
|
||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||
}
|
||||
lineCount++
|
||||
|
||||
return lineCount
|
||||
}
|
||||
|
||||
// selectPrompt prompts the user to select a single item from a list.
|
||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newSelectState(items)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
|
||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
ts, err := enterRawMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ts.restore()
|
||||
|
||||
state := newMultiSelectState(items, preChecked)
|
||||
var lastLineCount int
|
||||
|
||||
render := func() {
|
||||
clearLines(lastLineCount)
|
||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
for {
|
||||
event, char, err := parseInput(os.Stdin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done, result, err := state.handleInput(event, char)
|
||||
if done {
|
||||
clearLines(lastLineCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
render()
|
||||
}
|
||||
}
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
@@ -496,17 +56,3 @@ func confirmPrompt(prompt string) (bool, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterItems(items []selectItem, filter string) []selectItem {
|
||||
if filter == "" {
|
||||
return items
|
||||
}
|
||||
var result []selectItem
|
||||
filterLower := strings.ToLower(filter)
|
||||
for _, item := range items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,651 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterItems(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama3.2:latest"},
|
||||
{Name: "qwen2.5:7b"},
|
||||
{Name: "deepseek-v3:cloud"},
|
||||
{Name: "GPT-OSS:20b"},
|
||||
}
|
||||
|
||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||
result := filterItems(items, "")
|
||||
if len(result) != len(items) {
|
||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||
result := filterItems(items, "LLAMA")
|
||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||
result := filterItems(items, "gpt")
|
||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialMatch", func(t *testing.T) {
|
||||
result := filterItems(items, "deep")
|
||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||
result := filterItems(items, "nonexistent")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 items, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0, got %d", s.selected)
|
||||
}
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected empty filter, got %q", s.filter)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item1" || err != nil {
|
||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "item3"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "item3" || err != nil {
|
||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "nonexistent"
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != "" || err != errCancelled {
|
||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 2 {
|
||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 1 {
|
||||
t.Errorf("expected selected=1, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventChar, 'i')
|
||||
s.handleInput(eventChar, 't')
|
||||
s.handleInput(eventChar, 'e')
|
||||
s.handleInput(eventChar, 'm')
|
||||
s.handleInput(eventChar, '2')
|
||||
if s.filter != "item2" {
|
||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||
}
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||
t.Errorf("expected [item2], got %v", filtered)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.selected = 2
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "" {
|
||||
t.Errorf("expected filter='', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "test"
|
||||
s.selected = 2
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
|
||||
// move down 12 times (past the 10-item viewport)
|
||||
for range 12 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != 12 {
|
||||
t.Errorf("expected selected=12, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 3 {
|
||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
s.selected = 5
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventUp, 0)
|
||||
|
||||
if s.selected != 4 {
|
||||
t.Errorf("expected selected=4, got %d", s.selected)
|
||||
}
|
||||
if s.scrollOffset != 4 {
|
||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSelectState(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
{Name: "item3"},
|
||||
}
|
||||
|
||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||
if s.selectedCount() != 2 {
|
||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||
}
|
||||
if !s.checked[1] || !s.checked[2] {
|
||||
t.Error("expected item2 and item3 to be checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||
// order matters: first checked = default model
|
||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||
if s.selectedCount() != 1 {
|
||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.toggleItem()
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.toggleItem()
|
||||
if s.checked[0] {
|
||||
t.Error("expected item1 to be unchecked after toggle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||
s.highlighted = 1 // toggle item2
|
||||
s.toggleItem()
|
||||
|
||||
if len(s.checkOrder) != 2 {
|
||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||
}
|
||||
// should be [0, 2] (item1, item3) with item2 removed
|
||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventEnter, 0)
|
||||
if !s.checked[0] {
|
||||
t.Error("expected item1 to be checked after enter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
s.focusOnButton = true
|
||||
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
|
||||
if !done || err != nil {
|
||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||
}
|
||||
// result should preserve selection order
|
||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||
t.Errorf("expected [item2, item1], got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.focusOnButton = true
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != nil || err != nil {
|
||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("tab should not focus button when nothing selected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.handleInput(eventTab, 0)
|
||||
if !s.focusOnButton {
|
||||
t.Error("expected focus on button after first tab")
|
||||
}
|
||||
s.handleInput(eventTab, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus back on list after second tab")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
if !done || result != nil || err != errCancelled {
|
||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||
}
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected item1 (idx 0) to NOT be default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||
t.Error("expected isDefault=false when nothing checked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.highlighted != 1 {
|
||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 1
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focus to return to list on arrow key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.handleInput(eventChar, 'x')
|
||||
if s.filter != "x" {
|
||||
t.Errorf("expected filter='x', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newMultiSelectState(manyItems, nil)
|
||||
s.highlighted = 10
|
||||
s.scrollOffset = 5
|
||||
|
||||
s.handleInput(eventChar, 'x')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||
}
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.filter = "test"
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.filter != "tes" {
|
||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
s.filter = "x"
|
||||
s.focusOnButton = true
|
||||
s.handleInput(eventBackspace, 0)
|
||||
if s.focusOnButton {
|
||||
t.Error("expected focusOnButton=false after backspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseInput(t *testing.T) {
|
||||
t.Run("Enter", func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||
if err != nil || event != eventEnter || char != 0 {
|
||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||
if err != nil || event != eventEscape {
|
||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Tab", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||
if err != nil || event != eventTab {
|
||||
t.Errorf("expected eventTab, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Backspace", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||
if err != nil || event != eventBackspace {
|
||||
t.Errorf("expected eventBackspace, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||
if err != nil || event != eventUp {
|
||||
t.Errorf("expected eventUp, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DownArrow", func(t *testing.T) {
|
||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||
if err != nil || event != eventDown {
|
||||
t.Errorf("expected eventDown, got %v", event)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PrintableChars", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
char byte
|
||||
}{
|
||||
{"lowercase", 'a'},
|
||||
{"uppercase", 'Z'},
|
||||
{"digit", '5'},
|
||||
{"space", ' '},
|
||||
{"tilde", '~'},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||
if err != nil || event != eventChar || char != tt.char {
|
||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "first item"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Select:") {
|
||||
t.Error("expected prompt in output")
|
||||
}
|
||||
if !strings.Contains(output, "item1") {
|
||||
t.Error("expected item1 in output")
|
||||
}
|
||||
if !strings.Contains(output, "first item") {
|
||||
t.Error("expected description in output")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item2 in output")
|
||||
}
|
||||
if lineCount != 3 { // 1 prompt + 2 items
|
||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
s.filter = "xyz"
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||
manyItems := make([]selectItem, 15)
|
||||
for i := range manyItems {
|
||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
s := newSelectState(manyItems)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
// 15 items - 10 displayed = 5 more
|
||||
if !strings.Contains(buf.String(), "5 more") {
|
||||
t.Error("expected '5 more' indicator")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRenderMultiSelect(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1"},
|
||||
{Name: "item2"},
|
||||
}
|
||||
|
||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[x]") {
|
||||
t.Error("expected checked checkbox [x]")
|
||||
}
|
||||
if !strings.Contains(output, "[ ]") {
|
||||
t.Error("expected unchecked checkbox [ ]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "(default)") {
|
||||
t.Error("expected (default) marker for first checked item")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "2 selected") {
|
||||
t.Error("expected '2 selected' in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||
s := newMultiSelectState(items, nil)
|
||||
var buf bytes.Buffer
|
||||
renderMultiSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "Select at least one") {
|
||||
t.Error("expected 'Select at least one' helper text")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
@@ -659,255 +17,3 @@ func TestErrCancelled(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for selector.go
|
||||
|
||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||
// List with only one item should still work.
|
||||
func TestSelectState_SingleItem(t *testing.T) {
|
||||
items := []selectItem{{Name: "only-one"}}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Down should do nothing (already at bottom)
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Up should do nothing (already at top)
|
||||
s.handleInput(eventUp, 0)
|
||||
if s.selected != 0 {
|
||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||
}
|
||||
|
||||
// Enter should select the only item
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if !done || result != "only-one" || err != nil {
|
||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||
// List with exactly maxDisplayedItems items should not scroll.
|
||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||
items := make([]selectItem, maxDisplayedItems)
|
||||
for i := range items {
|
||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
|
||||
// Move to last item
|
||||
for range maxDisplayedItems - 1 {
|
||||
s.handleInput(eventDown, 0)
|
||||
}
|
||||
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
|
||||
// Should not scroll when exactly at max
|
||||
if s.scrollOffset != 0 {
|
||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||
}
|
||||
|
||||
// One more down should do nothing
|
||||
s.handleInput(eventDown, 0)
|
||||
if s.selected != maxDisplayedItems-1 {
|
||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||
// User typing "model.v1" shouldn't match "modelsv1".
|
||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "model.v1"},
|
||||
{Name: "modelsv1"},
|
||||
{Name: "model-v1"},
|
||||
}
|
||||
|
||||
// Filter with dot should only match literal dot
|
||||
result := filterItems(items, "model.v1")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||
}
|
||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||
}
|
||||
|
||||
// Other regex special chars should be literal too
|
||||
items2 := []selectItem{
|
||||
{Name: "test[0]"},
|
||||
{Name: "test0"},
|
||||
{Name: "test(1)"},
|
||||
}
|
||||
|
||||
result2 := filterItems(items2, "test[0]")
|
||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||
// the current behavior: the last index for a duplicate name is stored
|
||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||
items := []selectItem{
|
||||
{Name: "duplicate"},
|
||||
{Name: "duplicate"},
|
||||
{Name: "unique"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
|
||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||
// When there are duplicates, only the last occurrence's index is stored
|
||||
if s.itemIndex["duplicate"] != 1 {
|
||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||
}
|
||||
|
||||
// Toggle item at highlighted=0 (first "duplicate")
|
||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||
// So it actually toggles the SECOND duplicate item, not the first
|
||||
s.toggleItem()
|
||||
|
||||
// This documents the potentially surprising behavior:
|
||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||
if !s.checked[1] {
|
||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||
}
|
||||
if s.checked[0] {
|
||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||
// Prevents index-out-of-bounds on next keystroke
|
||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
s.selected = 2 // Select "cherry"
|
||||
|
||||
// Type a filter that removes cherry from results
|
||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||
|
||||
// Selection should reset to 0
|
||||
if s.selected != 0 {
|
||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||
}
|
||||
|
||||
filtered := s.filtered()
|
||||
if len(filtered) != 2 {
|
||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||
// Model names might contain unicode characters
|
||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "llama-日本語"},
|
||||
{Name: "模型-chinese"},
|
||||
{Name: "émoji-🦙"},
|
||||
{Name: "regular-model"},
|
||||
}
|
||||
|
||||
t.Run("filter japanese", func(t *testing.T) {
|
||||
result := filterItems(items, "日本")
|
||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||
t.Errorf("expected llama-日本語, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter chinese", func(t *testing.T) {
|
||||
result := filterItems(items, "模型")
|
||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||
t.Errorf("expected 模型-chinese, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter emoji", func(t *testing.T) {
|
||||
result := filterItems(items, "🦙")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter accented char", func(t *testing.T) {
|
||||
result := filterItems(items, "émoji")
|
||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||
t.Errorf("expected émoji-🦙, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "apple"},
|
||||
{Name: "banana"},
|
||||
{Name: "cherry"},
|
||||
}
|
||||
|
||||
s := newMultiSelectState(items, nil)
|
||||
s.highlighted = 2 // Highlight "cherry"
|
||||
|
||||
// Type a filter that removes cherry
|
||||
s.handleInput(eventChar, 'a')
|
||||
|
||||
if s.highlighted != 0 {
|
||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||
// Empty list should be handled gracefully.
|
||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||
s := newMultiSelectState([]selectItem{}, nil)
|
||||
|
||||
// Toggle should not panic on empty list
|
||||
s.toggleItem()
|
||||
|
||||
if s.selectedCount() != 0 {
|
||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||
}
|
||||
|
||||
// Render should handle empty list
|
||||
var buf bytes.Buffer
|
||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||
if lineCount == 0 {
|
||||
t.Error("renderMultiSelect should produce output even for empty list")
|
||||
}
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' for empty list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||
items := []selectItem{
|
||||
{Name: "item1", Description: "First item description"},
|
||||
{Name: "item2", Description: ""},
|
||||
{Name: "item3", Description: "Third item"},
|
||||
}
|
||||
|
||||
s := newSelectState(items)
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "First item description") {
|
||||
t.Error("expected description to be rendered")
|
||||
}
|
||||
if !strings.Contains(output, "item2") {
|
||||
t.Error("expected item without description to be rendered")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,19 +10,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return err
|
||||
return errNotRunning
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
return errNotRunning
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
|
||||
109
cmd/tui/confirm.go
Normal file
109
cmd/tui/confirm.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
confirmActiveStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
confirmInactiveStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
)
|
||||
|
||||
type confirmModel struct {
|
||||
prompt string
|
||||
yes bool
|
||||
confirmed bool
|
||||
cancelled bool
|
||||
width int
|
||||
}
|
||||
|
||||
func (m confirmModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc", "n":
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
case "y":
|
||||
m.yes = true
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
case "enter":
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
case "left", "h":
|
||||
m.yes = true
|
||||
case "right", "l":
|
||||
m.yes = false
|
||||
case "tab":
|
||||
m.yes = !m.yes
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m confirmModel) View() string {
|
||||
if m.confirmed || m.cancelled {
|
||||
return ""
|
||||
}
|
||||
|
||||
var yesBtn, noBtn string
|
||||
if m.yes {
|
||||
yesBtn = confirmActiveStyle.Render(" Yes ")
|
||||
noBtn = confirmInactiveStyle.Render(" No ")
|
||||
} else {
|
||||
yesBtn = confirmInactiveStyle.Render(" Yes ")
|
||||
noBtn = confirmActiveStyle.Render(" No ")
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
|
||||
s += " " + yesBtn + " " + noBtn + "\n\n"
|
||||
s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RunConfirm shows a bubbletea yes/no confirmation prompt.
|
||||
// Returns true if the user confirmed, false if cancelled.
|
||||
func RunConfirm(prompt string) (bool, error) {
|
||||
m := confirmModel{
|
||||
prompt: prompt,
|
||||
yes: true, // default to yes
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error running confirm: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(confirmModel)
|
||||
if fm.cancelled {
|
||||
return false, ErrCancelled
|
||||
}
|
||||
|
||||
return fm.yes, nil
|
||||
}
|
||||
208
cmd/tui/confirm_test.go
Normal file
208
cmd/tui/confirm_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func TestConfirmModel_DefaultsToYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download test?", yes: true}
|
||||
if !m.yes {
|
||||
t.Error("should default to yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Download qwen3:8b?") {
|
||||
t.Error("should contain the prompt text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsButtons(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Yes") {
|
||||
t.Error("should contain Yes button")
|
||||
}
|
||||
if !strings.Contains(got, "No") {
|
||||
t.Error("should contain No button")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "enter confirm") {
|
||||
t.Error("should contain help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", confirmed: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after confirmation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", cancelled: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if !fm.yes {
|
||||
t.Error("enter with yes selected should keep yes=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if fm.yes {
|
||||
t.Error("enter with no selected should keep yes=false")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EscCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_CtrlCCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_NCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("'n' should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("'n' should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_YConfirmsYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("'y' should set confirmed=true")
|
||||
}
|
||||
if !fm.yes {
|
||||
t.Error("'y' should set yes=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("'y' should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
// Right moves to No
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.yes {
|
||||
t.Error("right/l should move to No")
|
||||
}
|
||||
if fm.confirmed || fm.cancelled {
|
||||
t.Error("navigation should not confirm or cancel")
|
||||
}
|
||||
|
||||
// Left moves back to Yes
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}})
|
||||
fm = updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("left/h should move to Yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_TabToggles(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.yes {
|
||||
t.Error("tab should toggle from Yes to No")
|
||||
}
|
||||
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
fm = updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("tab should toggle from No to Yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.width != 100 {
|
||||
t.Errorf("expected width 100, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", width: 80}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
if cmd == nil {
|
||||
t.Error("resize (width already set) should return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
|
||||
if cmd != nil {
|
||||
t.Error("initial WindowSizeMsg should not return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ViewMaxWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true, width: 40}
|
||||
got := m.View()
|
||||
// Just ensure it doesn't panic and returns content
|
||||
if got == "" {
|
||||
t.Error("View with width set should still return content")
|
||||
}
|
||||
}
|
||||
654
cmd/tui/selector.go
Normal file
654
cmd/tui/selector.go
Normal file
@@ -0,0 +1,654 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
selectorTitleStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
selectorItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4)
|
||||
|
||||
selectorSelectedItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
selectorDescStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorDescLineStyle = selectorDescStyle.
|
||||
PaddingLeft(6)
|
||||
|
||||
selectorFilterStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
selectorInputStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
|
||||
|
||||
selectorCheckboxStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorCheckboxCheckedStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
|
||||
|
||||
selectorMoreStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(6).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
sectionHeaderStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
|
||||
)
|
||||
|
||||
const maxSelectorItems = 10
|
||||
|
||||
// ErrCancelled is returned when the user cancels the selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// selectorModel is the bubbletea model for single selection.
|
||||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
selected string
|
||||
cancelled bool
|
||||
helpText string
|
||||
width int
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// otherStart returns the index of the first non-recommended item in the filtered list.
|
||||
// When filtering, all items scroll together so this returns 0.
|
||||
func (m selectorModel) otherStart() int {
|
||||
if m.filter != "" {
|
||||
return 0
|
||||
}
|
||||
filtered := m.filteredItems()
|
||||
for i, item := range filtered {
|
||||
if !item.Recommended {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(filtered)
|
||||
}
|
||||
|
||||
// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
|
||||
// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
|
||||
// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
|
||||
func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
|
||||
filtered := m.filteredItems()
|
||||
otherStart := m.otherStart()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// updateScroll adjusts scrollOffset based on cursor position.
|
||||
// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
|
||||
// When filtering, it's relative to the full filtered list.
|
||||
func (m *selectorModel) updateScroll(otherStart int) {
|
||||
if m.filter != "" {
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in recommended section — reset "More" scroll to top
|
||||
if m.cursor < otherStart {
|
||||
m.scrollOffset = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in "More" section — scroll relative to others
|
||||
posInOthers := m.cursor - otherStart
|
||||
maxOthers := maxSelectorItems - otherStart
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
if posInOthers < m.scrollOffset {
|
||||
m.scrollOffset = posInOthers
|
||||
}
|
||||
if posInOthers >= m.scrollOffset+maxOthers {
|
||||
m.scrollOffset = posInOthers - maxOthers + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.selected = filtered[m.cursor].Name
|
||||
}
|
||||
return m, tea.Quit
|
||||
|
||||
default:
|
||||
m.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// renderContent renders the selector content (title, items, help text) without
|
||||
// checking the cancelled/selected state. This is used by both View() (standalone mode)
|
||||
// and by the TUI modal which embeds a selectorModel.
|
||||
func (m selectorModel) renderContent() string {
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else if m.filter != "" {
|
||||
s.WriteString(sectionHeaderStyle.Render("Top Results"))
|
||||
s.WriteString("\n")
|
||||
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
// Split into pinned recommended and scrollable others
|
||||
var recItems, otherItems []int
|
||||
for i, item := range filtered {
|
||||
if item.Recommended {
|
||||
recItems = append(recItems, i)
|
||||
} else {
|
||||
otherItems = append(otherItems, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Always render all recommended items (pinned)
|
||||
if len(recItems) > 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(otherItems) > 0 {
|
||||
s.WriteString("\n")
|
||||
s.WriteString(sectionHeaderStyle.Render("More"))
|
||||
s.WriteString("\n")
|
||||
|
||||
maxOthers := maxSelectorItems - len(recItems)
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
displayCount := min(len(otherItems), maxOthers)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
help := "↑/↓ navigate • enter select • esc cancel"
|
||||
if m.helpText != "" {
|
||||
help = m.helpText
|
||||
}
|
||||
s.WriteString(selectorHelpStyle.Render(help))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func (m selectorModel) View() string {
|
||||
if m.cancelled || m.selected != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := m.renderContent()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(selectorModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.selected, nil
|
||||
}
|
||||
|
||||
// multiSelectorModel is the bubbletea model for multi selection.
|
||||
type multiSelectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
m := multiSelectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) toggleItem() {
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) == 0 || m.cursor >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[m.cursor]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.checked[origIdx] {
|
||||
delete(m.checked, origIdx)
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == origIdx {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.checked[origIdx] = true
|
||||
m.checkOrder = append(m.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) selectedCount() int {
|
||||
return len(m.checkOrder)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
m.toggleItem()
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.scrollOffset -= maxSelectorItems
|
||||
if m.scrollOffset < 0 {
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) View() string {
|
||||
if m.cancelled || m.confirmed {
|
||||
return ""
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else {
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
shownRecHeader := false
|
||||
prevWasRec := false
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
item := filtered[idx]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.filter == "" {
|
||||
if item.Recommended && !shownRecHeader {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
shownRecHeader = true
|
||||
} else if !item.Recommended && prevWasRec {
|
||||
s.WriteString("\n")
|
||||
}
|
||||
prevWasRec = item.Recommended
|
||||
}
|
||||
|
||||
var checkbox string
|
||||
if m.checked[origIdx] {
|
||||
checkbox = selectorCheckboxCheckedStyle.Render("[x]")
|
||||
} else {
|
||||
checkbox = selectorCheckboxStyle.Render("[ ]")
|
||||
}
|
||||
|
||||
var line string
|
||||
if idx == m.cursor {
|
||||
line = selectorSelectedItemStyle.Render("▸ ") + checkbox + " " + selectorSelectedItemStyle.Render(item.Name)
|
||||
} else {
|
||||
line = " " + checkbox + " " + item.Name
|
||||
}
|
||||
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
line += " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
s.WriteString(line)
|
||||
s.WriteString("\n")
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := newMultiSelectorModel(title, items, preChecked)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, idx := range fm.checkOrder {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
410
cmd/tui/selector_test.go
Normal file
410
cmd/tui/selector_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func items(names ...string) []SelectItem {
|
||||
var out []SelectItem
|
||||
for _, n := range names {
|
||||
out = append(out, SelectItem{Name: n})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func recItems(names ...string) []SelectItem {
|
||||
var out []SelectItem
|
||||
for _, n := range names {
|
||||
out = append(out, SelectItem{Name: n, Recommended: true})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mixedItems() []SelectItem {
|
||||
return []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "rec-b", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
{Name: "other-2"},
|
||||
{Name: "other-3"},
|
||||
{Name: "other-4"},
|
||||
{Name: "other-5"},
|
||||
{Name: "other-6"},
|
||||
{Name: "other-7"},
|
||||
{Name: "other-8"},
|
||||
{Name: "other-9"},
|
||||
{Name: "other-10"},
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredItems(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
items []SelectItem
|
||||
filter string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "no filter returns all",
|
||||
items: items("alpha", "beta", "gamma"),
|
||||
filter: "",
|
||||
want: []string{"alpha", "beta", "gamma"},
|
||||
},
|
||||
{
|
||||
name: "filter matches substring",
|
||||
items: items("llama3.2", "qwen3:8b", "llama2"),
|
||||
filter: "llama",
|
||||
want: []string{"llama3.2", "llama2"},
|
||||
},
|
||||
{
|
||||
name: "filter is case insensitive",
|
||||
items: items("Qwen3:8b", "llama3.2"),
|
||||
filter: "QWEN",
|
||||
want: []string{"Qwen3:8b"},
|
||||
},
|
||||
{
|
||||
name: "no matches",
|
||||
items: items("alpha", "beta"),
|
||||
filter: "zzz",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{items: tt.items, filter: tt.filter}
|
||||
got := m.filteredItems()
|
||||
var gotNames []string
|
||||
for _, item := range got {
|
||||
gotNames = append(gotNames, item.Name)
|
||||
}
|
||||
if len(gotNames) != len(tt.want) {
|
||||
t.Fatalf("got %v, want %v", gotNames, tt.want)
|
||||
}
|
||||
for i := range tt.want {
|
||||
if gotNames[i] != tt.want[i] {
|
||||
t.Errorf("index %d: got %q, want %q", i, gotNames[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOtherStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
items []SelectItem
|
||||
filter string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "all recommended",
|
||||
items: recItems("a", "b", "c"),
|
||||
want: 3,
|
||||
},
|
||||
{
|
||||
name: "none recommended",
|
||||
items: items("a", "b"),
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "rec-b", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
{Name: "other-2"},
|
||||
},
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
items: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "filtering returns 0",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
},
|
||||
filter: "rec",
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{items: tt.items, filter: tt.filter}
|
||||
if got := m.otherStart(); got != tt.want {
|
||||
t.Errorf("otherStart() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateScroll(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cursor int
|
||||
offset int
|
||||
otherStart int
|
||||
filter string
|
||||
wantOffset int
|
||||
}{
|
||||
{
|
||||
name: "cursor in recommended resets scroll",
|
||||
cursor: 1,
|
||||
offset: 5,
|
||||
otherStart: 3,
|
||||
wantOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "cursor at start of others",
|
||||
cursor: 2,
|
||||
offset: 0,
|
||||
otherStart: 2,
|
||||
wantOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "cursor scrolls down in others",
|
||||
cursor: 12,
|
||||
offset: 0,
|
||||
otherStart: 2,
|
||||
wantOffset: 3, // posInOthers=10, maxOthers=8, 10-8+1=3
|
||||
},
|
||||
{
|
||||
name: "cursor scrolls up in others",
|
||||
cursor: 4,
|
||||
offset: 5,
|
||||
otherStart: 2,
|
||||
wantOffset: 2, // posInOthers=2 < offset=5
|
||||
},
|
||||
{
|
||||
name: "filter mode standard scroll down",
|
||||
cursor: 12,
|
||||
offset: 0,
|
||||
filter: "x",
|
||||
otherStart: 0,
|
||||
wantOffset: 3, // 12 - 10 + 1 = 3
|
||||
},
|
||||
{
|
||||
name: "filter mode standard scroll up",
|
||||
cursor: 2,
|
||||
offset: 5,
|
||||
filter: "x",
|
||||
otherStart: 0,
|
||||
wantOffset: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := selectorModel{
|
||||
cursor: tt.cursor,
|
||||
scrollOffset: tt.offset,
|
||||
filter: tt.filter,
|
||||
}
|
||||
m.updateScroll(tt.otherStart)
|
||||
if m.scrollOffset != tt.wantOffset {
|
||||
t.Errorf("scrollOffset = %d, want %d", m.scrollOffset, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: []SelectItem{
|
||||
{Name: "rec-a", Recommended: true},
|
||||
{Name: "other-1"},
|
||||
},
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "Recommended") {
|
||||
t.Error("should contain 'Recommended' header")
|
||||
}
|
||||
if !strings.Contains(content, "More") {
|
||||
t.Error("should contain 'More' header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_FilteredHeader(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha", "beta", "alphabet"),
|
||||
filter: "alpha",
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "Top Results") {
|
||||
t.Error("filtered view should contain 'Top Results' header")
|
||||
}
|
||||
if strings.Contains(content, "Recommended") {
|
||||
t.Error("filtered view should not contain 'Recommended' header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_NoMatches(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha"),
|
||||
filter: "zzz",
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "(no matches)") {
|
||||
t.Error("should show '(no matches)' when filter has no results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SelectedItemIndicator(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: items("alpha", "beta"),
|
||||
cursor: 0,
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("selected item should have ▸ indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_Description(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: []SelectItem{
|
||||
{Name: "alpha", Description: "the first letter"},
|
||||
},
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "the first letter") {
|
||||
t.Error("should render item description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_PinnedRecommended(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: mixedItems(),
|
||||
// cursor deep in "More" section
|
||||
cursor: 8,
|
||||
scrollOffset: 3,
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
// Recommended items should always be visible (pinned)
|
||||
if !strings.Contains(content, "rec-a") {
|
||||
t.Error("recommended items should always be rendered (pinned)")
|
||||
}
|
||||
if !strings.Contains(content, "rec-b") {
|
||||
t.Error("recommended items should always be rendered (pinned)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_MoreOverflowIndicator(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
items: mixedItems(), // 2 rec + 10 other = 12 total, maxSelectorItems=10
|
||||
}
|
||||
content := m.renderContent()
|
||||
|
||||
if !strings.Contains(content, "... and") {
|
||||
t.Error("should show overflow indicator when more items than visible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_CursorBounds(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("a", "b", "c"),
|
||||
cursor: 0,
|
||||
}
|
||||
|
||||
// Up at top stays at 0
|
||||
m.updateNavigation(keyMsg(KeyUp))
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should stay at 0 when pressing up at top, got %d", m.cursor)
|
||||
}
|
||||
|
||||
// Down moves to 1
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be 1 after down, got %d", m.cursor)
|
||||
}
|
||||
|
||||
// Down to end
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
m.updateNavigation(keyMsg(KeyDown))
|
||||
if m.cursor != 2 {
|
||||
t.Errorf("cursor should be 2 at bottom, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_FilterResetsState(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("alpha", "beta"),
|
||||
cursor: 1,
|
||||
scrollOffset: 5,
|
||||
}
|
||||
|
||||
m.updateNavigation(runeMsg('x'))
|
||||
if m.filter != "x" {
|
||||
t.Errorf("filter should be 'x', got %q", m.filter)
|
||||
}
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should reset to 0 on filter, got %d", m.cursor)
|
||||
}
|
||||
if m.scrollOffset != 0 {
|
||||
t.Errorf("scrollOffset should reset to 0 on filter, got %d", m.scrollOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNavigation_Backspace(t *testing.T) {
|
||||
m := selectorModel{
|
||||
items: items("alpha"),
|
||||
filter: "abc",
|
||||
cursor: 1,
|
||||
}
|
||||
|
||||
m.updateNavigation(keyMsg(KeyBackspace))
|
||||
if m.filter != "ab" {
|
||||
t.Errorf("filter should be 'ab' after backspace, got %q", m.filter)
|
||||
}
|
||||
if m.cursor != 0 {
|
||||
t.Errorf("cursor should reset to 0 on backspace, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
const (
|
||||
KeyUp keyType = iota
|
||||
KeyDown keyType = iota
|
||||
KeyBackspace keyType = iota
|
||||
)
|
||||
|
||||
func keyMsg(k keyType) tea.KeyMsg {
|
||||
switch k {
|
||||
case KeyUp:
|
||||
return tea.KeyMsg{Type: tea.KeyUp}
|
||||
case KeyDown:
|
||||
return tea.KeyMsg{Type: tea.KeyDown}
|
||||
case KeyBackspace:
|
||||
return tea.KeyMsg{Type: tea.KeyBackspace}
|
||||
default:
|
||||
return tea.KeyMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
func runeMsg(r rune) tea.KeyMsg {
|
||||
return tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}}
|
||||
}
|
||||
128
cmd/tui/signin.go
Normal file
128
cmd/tui/signin.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
spinner int
|
||||
width int
|
||||
userName string
|
||||
cancelled bool
|
||||
}
|
||||
|
||||
func (m signInModel) Init() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.spinner++
|
||||
if m.spinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
m.userName = msg.userName
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m signInModel) View() string {
|
||||
if m.userName != "" {
|
||||
return ""
|
||||
}
|
||||
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
|
||||
}
|
||||
|
||||
func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
||||
urlColor := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("117"))
|
||||
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
|
||||
if width > 4 {
|
||||
urlWrap = urlWrap.Width(width - 4)
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
|
||||
// Padding is outside the hyperlink so spaces don't get underlined.
|
||||
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(link))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
frame + " Waiting for sign in to complete..."))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
signInURL: signInURL,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running sign-in: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(signInModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.userName, nil
|
||||
}
|
||||
175
cmd/tui/signin_test.go
Normal file
175
cmd/tui/signin_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func TestRenderSignIn_ContainsModelName(t *testing.T) {
|
||||
got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
|
||||
if !strings.Contains(got, "glm-4.7:cloud") {
|
||||
t.Error("should contain model name")
|
||||
}
|
||||
if !strings.Contains(got, "please sign in") {
|
||||
t.Error("should contain sign-in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsURL(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
if !strings.Contains(got, url) {
|
||||
t.Errorf("should contain URL %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
|
||||
// Should contain OSC 8 open sequence with the URL
|
||||
osc8Open := "\033]8;;" + url + "\033\\"
|
||||
if !strings.Contains(got, osc8Open) {
|
||||
t.Error("should contain OSC 8 open sequence with URL")
|
||||
}
|
||||
|
||||
// Should contain OSC 8 close sequence
|
||||
osc8Close := "\033]8;;\033\\"
|
||||
if !strings.Contains(got, osc8Close) {
|
||||
t.Error("should contain OSC 8 close sequence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "Waiting for sign in to complete") {
|
||||
t.Error("should contain waiting message")
|
||||
}
|
||||
if !strings.Contains(got, "⠋") {
|
||||
t.Error("should contain first spinner frame at spinner=0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
|
||||
got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
|
||||
if got0 == got1 {
|
||||
t.Error("different spinner values should produce different output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "esc cancel") {
|
||||
t.Error("should contain esc cancel help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_EscCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_CtrlCCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInQuitsClean(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "alice" {
|
||||
t.Errorf("expected userName 'alice', got %q", fm.userName)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("successful sign-in should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInViewClears(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
userName: "alice",
|
||||
}
|
||||
|
||||
got := m.View()
|
||||
if got != "" {
|
||||
t.Errorf("View should return empty string after sign-in, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_NotSignedInContinues(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(signInCheckMsg{signedIn: false})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "" {
|
||||
t.Error("should not set userName when not signed in")
|
||||
}
|
||||
if fm.cancelled {
|
||||
t.Error("should not cancel when check returns not signed in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
fm := updated.(signInModel)
|
||||
if fm.width != 120 {
|
||||
t.Errorf("expected width 120, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
spinner: 0,
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInTickMsg{})
|
||||
fm := updated.(signInModel)
|
||||
if fm.spinner != 1 {
|
||||
t.Errorf("expected spinner=1, got %d", fm.spinner)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("tick should return a command")
|
||||
}
|
||||
}
|
||||
603
cmd/tui/tui.go
Normal file
603
cmd/tui/tui.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
var (
|
||||
versionStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
menuItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2)
|
||||
|
||||
menuSelectedItemStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
menuDescStyle = selectorDescStyle.
|
||||
PaddingLeft(4)
|
||||
|
||||
greyedStyle = menuItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
greyedSelectedStyle = menuSelectedItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
modelStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
notInstalledStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
|
||||
var mainMenuItems = []menuItem{
|
||||
{
|
||||
title: "Run a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Agentic coding across large codebases",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Codex",
|
||||
description: "OpenAI's open-source coding agent",
|
||||
integration: "codex",
|
||||
},
|
||||
{
|
||||
title: "Launch OpenClaw",
|
||||
description: "Personal AI with 100+ skills",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
|
||||
var othersMenuItem = menuItem{
|
||||
title: "More...",
|
||||
description: "Show additional integrations",
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
||||
func getOtherIntegrations() []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"run": true, // not an integration but in the pinned list
|
||||
}
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration != "" {
|
||||
pinned[item.integration] = true
|
||||
}
|
||||
}
|
||||
|
||||
var others []menuItem
|
||||
for _, info := range config.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
desc := info.Description
|
||||
if desc == "" {
|
||||
desc = "Open " + info.DisplayName + " integration"
|
||||
}
|
||||
others = append(others, menuItem{
|
||||
title: "Launch " + info.DisplayName,
|
||||
description: desc,
|
||||
integration: info.Name,
|
||||
})
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool
|
||||
changeModel bool
|
||||
showOthers bool
|
||||
availableModels map[string]bool
|
||||
err error
|
||||
|
||||
showingModal bool
|
||||
modalSelector selectorModel
|
||||
modalItems []SelectItem
|
||||
|
||||
showingSignIn bool
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
}
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
var items []SelectItem
|
||||
for _, item := range modelItems {
|
||||
items = append(items, SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (m *model) openModelModal() {
|
||||
m.modalItems = m.buildModalItems()
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
helpText: "↑/↓ navigate • enter select • ← back",
|
||||
}
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud")
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, mdl := range models.Models {
|
||||
m.availableModels[mdl.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
}
|
||||
}
|
||||
|
||||
func isOthersIntegration(name string) bool {
|
||||
for _, item := range getOtherIntegrations() {
|
||||
if item.integration == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
||||
wasSet := m.width > 0
|
||||
m.width = wmsg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if _, ok := msg.(clearStatusMsg); ok {
|
||||
m.statusMsg = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
if m.signInFromModal {
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
||||
m.modalSelector.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
m.showOthers = false
|
||||
m.buildItems()
|
||||
}
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
// Auto-expand "Others..." when cursor lands on it
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.buildItems()
|
||||
// cursor now points at the first "other" integration
|
||||
}
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
m.openModelModal()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = menuSelectedItemStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
} else if item.isRunModel && m.cursor == i {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
}
|
||||
}
|
||||
s += menuDescStyle.Render(desc) + "\n\n"
|
||||
}
|
||||
|
||||
if m.statusMsg != "" {
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
PaddingBottom(1).
|
||||
PaddingRight(2)
|
||||
|
||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderSignInDialog() string {
|
||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
||||
}
|
||||
|
||||
type Selection int
|
||||
|
||||
const (
|
||||
SelectionNone Selection = iota
|
||||
SelectionRunModel
|
||||
SelectionChangeRunModel
|
||||
SelectionIntegration // Generic integration selection
|
||||
SelectionChangeIntegration // Generic change model for integration
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Selection Selection
|
||||
Integration string // integration name if applicable
|
||||
Model string // model name if selected from modal
|
||||
}
|
||||
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
}
|
||||
@@ -5,7 +5,10 @@ title: Context length
|
||||
Context length is the maximum number of tokens that the model has access to in memory.
|
||||
|
||||
<Note>
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
Ollama defaults to the following context lengths based on VRAM:
|
||||
- < 24 GiB VRAM: 4k context
|
||||
- 24-48 GiB VRAM: 32k context
|
||||
- >= 48 GiB VRAM: 256k context
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
|
||||
|
||||
@@ -105,21 +105,52 @@
|
||||
{
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/cline",
|
||||
"/integrations/openclaw",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/onyx",
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
"/integrations/index",
|
||||
{
|
||||
"group": "Coding",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/codex",
|
||||
"/integrations/opencode",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Assistants",
|
||||
"pages": [
|
||||
"/integrations/openclaw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "IDEs & Editors",
|
||||
"pages": [
|
||||
"/integrations/cline",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Chat & RAG",
|
||||
"pages": [
|
||||
"/integrations/onyx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Automation",
|
||||
"pages": [
|
||||
"/integrations/n8n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "Notebooks",
|
||||
"pages": [
|
||||
"/integrations/marimo"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
22
docs/faq.mdx
22
docs/faq.mdx
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting.mdx) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
Please refer to the [GPU docs](./gpu.mdx).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
@@ -66,7 +66,7 @@ llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
</Info>
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
The `Processor` column will show which memory the model was loaded into:
|
||||
|
||||
- `100% GPU` means the model was loaded entirely into the GPU
|
||||
- `100% CPU` means the model was loaded entirely in system memory
|
||||
@@ -158,7 +158,7 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
|
||||
|
||||
## Does Ollama send my prompts and answers back to ollama.com?
|
||||
|
||||
No. Ollama runs locally, and conversation data does not leave your machine.
|
||||
Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
@@ -183,7 +183,7 @@ server {
|
||||
|
||||
## How can I use Ollama with ngrok?
|
||||
|
||||
Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok:
|
||||
Ollama can be accessed using a range of tunneling apps. For example with Ngrok:
|
||||
|
||||
```shell
|
||||
ngrok http 11434 --host-header="localhost:11434"
|
||||
@@ -240,7 +240,7 @@ GPU acceleration is not available for Docker Desktop in macOS due to the lack of
|
||||
|
||||
This can impact both installing Ollama, as well as downloading models.
|
||||
|
||||
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
|
||||
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernet (WSL)` adapter, right click and select `Properties`.
|
||||
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. _Disable_ both of these
|
||||
properties.
|
||||
|
||||
@@ -299,7 +299,7 @@ The `keep_alive` API parameter with the `/api/generate` and `/api/chat` API endp
|
||||
|
||||
## How do I manage the maximum number of requests the Ollama server can queue?
|
||||
|
||||
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
|
||||
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
|
||||
|
||||
## How does Ollama handle concurrent requests?
|
||||
|
||||
@@ -312,10 +312,10 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPU's VRAM.
|
||||
|
||||
## How does Ollama load models on multiple GPUs?
|
||||
|
||||
@@ -382,7 +382,7 @@ ollama signin
|
||||
Replace <username> with your actual Windows user name.
|
||||
</Note>
|
||||
|
||||
## How can I stop Ollama from starting when I login to my computer
|
||||
## How can I stop Ollama from starting when I login to my computer?
|
||||
|
||||
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
|
||||
|
||||
@@ -390,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
|
||||
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
|
||||
|
||||
**MacOS**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under `Allow in the Background`, then click the slider to disable.
|
||||
|
||||
50
docs/integrations/index.mdx
Normal file
50
docs/integrations/index.mdx
Normal file
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Overview
|
||||
---
|
||||
|
||||
Ollama integrates with a wide range of tools.
|
||||
|
||||
## Coding Agents
|
||||
|
||||
Coding assistants that can read, modify, and execute code in your projects.
|
||||
|
||||
- [Claude Code](/integrations/claude-code)
|
||||
- [Codex](/integrations/codex)
|
||||
- [OpenCode](/integrations/opencode)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
|
||||
## Assistants
|
||||
|
||||
AI assistants that help with everyday tasks.
|
||||
|
||||
- [OpenClaw](/integrations/openclaw)
|
||||
|
||||
## IDEs & Editors
|
||||
|
||||
Native integrations for popular development environments.
|
||||
|
||||
- [VS Code](/integrations/vscode)
|
||||
- [Cline](/integrations/cline)
|
||||
- [Roo Code](/integrations/roo-code)
|
||||
- [JetBrains](/integrations/jetbrains)
|
||||
- [Xcode](/integrations/xcode)
|
||||
- [Zed](/integrations/zed)
|
||||
|
||||
## Chat & RAG
|
||||
|
||||
Chat interfaces and retrieval-augmented generation platforms.
|
||||
|
||||
- [Onyx](/integrations/onyx)
|
||||
|
||||
## Automation
|
||||
|
||||
Workflow automation platforms with AI integration.
|
||||
|
||||
- [n8n](/integrations/n8n)
|
||||
|
||||
## Notebooks
|
||||
|
||||
Interactive computing environments with AI capabilities.
|
||||
|
||||
- [marimo](/integrations/marimo)
|
||||
@@ -2,7 +2,7 @@
|
||||
title: Quickstart
|
||||
---
|
||||
|
||||
This quickstart will walk your through running your first model with Ollama. To get started, download Ollama on macOS, Windows or Linux.
|
||||
Ollama is available on macOS, Windows, and Linux.
|
||||
|
||||
<a
|
||||
href="https://ollama.com/download"
|
||||
@@ -12,131 +12,48 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
Download Ollama
|
||||
</a>
|
||||
|
||||
## Run a model
|
||||
## Get Started
|
||||
|
||||
<Tabs>
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "gemma3",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello there!"
|
||||
}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```python
|
||||
from ollama import chat
|
||||
from ollama import ChatResponse
|
||||
|
||||
response: ChatResponse = chat(model='gemma3', messages=[
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
},
|
||||
])
|
||||
print(response['message']['content'])
|
||||
# or access fields directly from the response object
|
||||
print(response.message.content)
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="JavaScript">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install the Ollama JavaScript library:
|
||||
```
|
||||
npm i ollama
|
||||
```
|
||||
|
||||
Lastly, chat with the model:
|
||||
|
||||
```shell
|
||||
import ollama from 'ollama'
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'gemma3',
|
||||
messages: [{ role: 'user', content: 'Why is the sky blue?' }],
|
||||
})
|
||||
console.log(response.message.content)
|
||||
```
|
||||
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
Run `ollama` in your terminal to open the interactive menu:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
ollama
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
Navigate with `↑/↓`, press `enter` to launch, `→` to change model, and `esc` to quit.
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
The menu provides quick access to:
|
||||
- **Run a model** - Start an interactive chat
|
||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||
- **Additional integrations** - Available under "More..."
|
||||
|
||||
### Launch with a specific model
|
||||
## Coding
|
||||
|
||||
Launch coding tools with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
```sh
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
See [integrations](/integrations) for all supported tools.
|
||||
|
||||
## API
|
||||
|
||||
Use the [API](/api) to integrate Ollama into your applications:
|
||||
|
||||
```sh
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "gemma3",
|
||||
"messages": [{ "role": "user", "content": "Hello!" }]
|
||||
}'
|
||||
```
|
||||
|
||||
See the [API documentation](/api) for Python, JavaScript, and other integrations.
|
||||
|
||||
@@ -188,8 +188,6 @@ func LogLevel() slog.Level {
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -275,27 +273,26 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
23
go.mod
23
go.mod
@@ -13,7 +13,7 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
@@ -21,14 +21,18 @@ require (
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
@@ -38,22 +42,35 @@ require (
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-pointer v0.0.1 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
|
||||
67
go.sum
67
go.sum
@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
|
||||
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
|
||||
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
|
||||
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
@@ -148,13 +164,19 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
||||
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -162,6 +184,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
|
||||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
@@ -182,8 +210,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
@@ -206,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
|
||||
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -220,6 +276,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -306,6 +364,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -116,7 +117,7 @@ type llamaServer struct {
|
||||
type ollamaServer struct {
|
||||
llmServer
|
||||
|
||||
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||
tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
|
||||
}
|
||||
|
||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||
@@ -142,11 +143,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
var tok tokenizer.Tokenizer
|
||||
var err error
|
||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||
if len(projectors) == 0 {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
tok, err = model.NewTextProcessor(modelPath)
|
||||
} else {
|
||||
err = errors.New("split vision models aren't supported")
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -211,7 +212,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if textProcessor == nil {
|
||||
if tok == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
@@ -261,7 +262,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
textProcessor != nil,
|
||||
tok != nil,
|
||||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
@@ -310,8 +311,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
}
|
||||
}()
|
||||
|
||||
if textProcessor != nil {
|
||||
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
|
||||
if tok != nil {
|
||||
return &ollamaServer{llmServer: s, tokenizer: tok}, nil
|
||||
} else {
|
||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||
}
|
||||
@@ -1774,7 +1775,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
tokens, err := s.textProcessor.Encode(content, false)
|
||||
tokens, err := s.tokenizer.Encode(content, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1809,7 +1810,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
content, err := s.tokenizer.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model),
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", ids)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,7 +134,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
tp, ok := m.(tokenizer.Tokenizer)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocab := &model.Vocabulary{
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(vocab, true)
|
||||
t = tokenizer.NewWordPiece(vocab, true)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Tokenizer: t,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
Sam *samModel `gguf:"s"`
|
||||
Vision *visionModel `gguf:"v"`
|
||||
@@ -134,8 +135,8 @@ func init() {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -43,8 +44,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
poolingType pooling.Type
|
||||
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
var t tokenizer.Tokenizer
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
t = tokenizer.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
t = tokenizer.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Tokenizer: t,
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewSentencePiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
@@ -198,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -236,8 +237,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -37,8 +38,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
|
||||
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Transformer struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TransformerBlocks []TransformerBlock `gguf:"blk"`
|
||||
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -59,7 +60,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -78,7 +79,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -104,8 +105,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
var processor tokenizer.Tokenizer
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
processor = tokenizer.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -28,12 +29,12 @@ type Model struct {
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
var _ tokenizer.Tokenizer = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -32,8 +33,8 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
headDim := hiddenSize / numHeads
|
||||
|
||||
processor := model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
)
|
||||
|
||||
blockCount := int(c.Uint("block_count"))
|
||||
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
|
||||
layers := make([]EncoderLayer, blockCount)
|
||||
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: layers,
|
||||
Tokenizer: tokenizer.NewWordPiece(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +34,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -44,28 +45,24 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
processor := model.NewBytePairEncoding(
|
||||
&vocabulary,
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []DecoderLayer `gguf:"blk"`
|
||||
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
@@ -207,7 +208,7 @@ func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outp
|
||||
// Model is the main Qwen3-Next model
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
@@ -353,8 +354,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
},
|
||||
true, // lowercase
|
||||
)
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
|
||||
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
|
||||
decoder := func(tokenID int) string {
|
||||
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
|
||||
text, _ := tok.Decode([]int32{int32(tokenID)})
|
||||
return text
|
||||
}
|
||||
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
|
||||
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -764,7 +765,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
@@ -773,14 +774,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
|
||||
if err != nil {
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
|
||||
if seq.logprobs {
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
|
||||
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
|
||||
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
|
||||
}
|
||||
|
||||
@@ -878,7 +879,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
@@ -11,22 +12,15 @@ func Execute(args []string) error {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var mlxRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--mlx-engine" {
|
||||
args = args[1:]
|
||||
mlxRunner = true
|
||||
}
|
||||
|
||||
if mlxRunner {
|
||||
return mlxrunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
case "--ollama-engine":
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--imagegen-engine":
|
||||
return imagegen.Execute(args[1:])
|
||||
case "--mlx-engine":
|
||||
return mlxrunner.Execute(args[1:])
|
||||
}
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
@@ -168,15 +168,15 @@ type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(tok.Vocabulary().Values))
|
||||
pieces := make([]string, len(tok.Vocabulary().Values))
|
||||
for i := range tok.Vocabulary().Values {
|
||||
pieces[i], _ = tok.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
func modelHelper(t testing.TB) tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
return tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
Merges: merges,
|
||||
|
||||
@@ -302,12 +302,22 @@ function deps {
|
||||
}
|
||||
|
||||
function sign {
|
||||
# Copy install.ps1 to dist for release packaging
|
||||
write-host "Copying install.ps1 to dist"
|
||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
|
||||
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
write-host "Signing Ollama executables, scripts and libraries"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
|
||||
write-host "Signing install.ps1"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
"${script:SRC_DIR}\dist\install.ps1"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "Signing not enabled"
|
||||
}
|
||||
|
||||
271
scripts/install.ps1
Normal file
271
scripts/install.ps1
Normal file
@@ -0,0 +1,271 @@
|
||||
<#
|
||||
.SYNOPSIS
|
||||
Install, upgrade, or uninstall Ollama on Windows.
|
||||
|
||||
.DESCRIPTION
|
||||
Downloads and installs Ollama.
|
||||
|
||||
Quick install:
|
||||
|
||||
irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
Specific version:
|
||||
|
||||
$env:OLLAMA_VERSION="0.5.7"; irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
Custom install directory:
|
||||
|
||||
$env:OLLAMA_INSTALL_DIR="D:\Ollama"; irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
Uninstall:
|
||||
|
||||
$env:OLLAMA_UNINSTALL=1; irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
Environment variables:
|
||||
|
||||
OLLAMA_VERSION Target version (default: latest stable)
|
||||
OLLAMA_INSTALL_DIR Custom install directory
|
||||
OLLAMA_UNINSTALL Set to 1 to uninstall Ollama
|
||||
OLLAMA_DEBUG Enable verbose output
|
||||
|
||||
.EXAMPLE
|
||||
irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
.EXAMPLE
|
||||
$env:OLLAMA_VERSION = "0.5.7"; irm https://ollama.com/install.ps1 | iex
|
||||
|
||||
.LINK
|
||||
https://ollama.com
|
||||
#>
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
$ProgressPreference = "SilentlyContinue"
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Configuration from environment variables
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
$Version = if ($env:OLLAMA_VERSION) { $env:OLLAMA_VERSION } else { "" }
|
||||
$InstallDir = if ($env:OLLAMA_INSTALL_DIR) { $env:OLLAMA_INSTALL_DIR } else { "" }
|
||||
$Uninstall = $env:OLLAMA_UNINSTALL -eq "1"
|
||||
$DebugInstall = [bool]$env:OLLAMA_DEBUG
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Constants
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# OLLAMA_DOWNLOAD_URL for developer testing only
|
||||
$DownloadBaseURL = if ($env:OLLAMA_DOWNLOAD_URL) { $env:OLLAMA_DOWNLOAD_URL.TrimEnd('/') } else { "https://ollama.com/download" }
|
||||
$InnoSetupUninstallGuid = "{44E83376-CE68-45EB-8FC1-393500EB558C}_is1"
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
function Write-Status {
|
||||
param([string]$Message)
|
||||
if ($DebugInstall) { Write-Host $Message }
|
||||
}
|
||||
|
||||
function Write-Step {
|
||||
param([string]$Message)
|
||||
if ($DebugInstall) { Write-Host ">>> $Message" -ForegroundColor Cyan }
|
||||
}
|
||||
|
||||
function Test-Signature {
|
||||
param([string]$FilePath)
|
||||
|
||||
$sig = Get-AuthenticodeSignature -FilePath $FilePath
|
||||
if ($sig.Status -ne "Valid") {
|
||||
Write-Status " Signature status: $($sig.Status)"
|
||||
return $false
|
||||
}
|
||||
|
||||
# Verify it's signed by Ollama Inc. (check exact organization name)
|
||||
# Anchor with comma/boundary to prevent "O=Not Ollama Inc." from matching
|
||||
$subject = $sig.SignerCertificate.Subject
|
||||
if ($subject -notmatch "(^|, )O=Ollama Inc\.(,|$)") {
|
||||
Write-Status " Unexpected signer: $subject"
|
||||
return $false
|
||||
}
|
||||
|
||||
Write-Status " Signature valid: $subject"
|
||||
return $true
|
||||
}
|
||||
|
||||
function Find-InnoSetupInstall {
|
||||
# Check both HKCU (per-user) and HKLM (per-machine) locations
|
||||
$possibleKeys = @(
|
||||
"HKCU:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
|
||||
"HKLM:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
|
||||
"HKLM:\Software\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid"
|
||||
)
|
||||
|
||||
foreach ($key in $possibleKeys) {
|
||||
if (Test-Path $key) {
|
||||
Write-Status " Found install at: $key"
|
||||
return $key
|
||||
}
|
||||
}
|
||||
return $null
|
||||
}
|
||||
|
||||
function Update-SessionPath {
|
||||
# Update PATH in current session so 'ollama' works immediately
|
||||
if ($InstallDir) {
|
||||
$ollamaDir = $InstallDir
|
||||
} else {
|
||||
$ollamaDir = Join-Path $env:LOCALAPPDATA "Programs\Ollama"
|
||||
}
|
||||
|
||||
# Add to PATH if not already present
|
||||
if (Test-Path $ollamaDir) {
|
||||
$currentPath = $env:PATH -split ';'
|
||||
if ($ollamaDir -notin $currentPath) {
|
||||
$env:PATH = "$ollamaDir;$env:PATH"
|
||||
Write-Status " Added $ollamaDir to session PATH"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function Invoke-Download {
|
||||
param(
|
||||
[string]$Url,
|
||||
[string]$OutFile
|
||||
)
|
||||
|
||||
Write-Status " Downloading: $Url"
|
||||
try {
|
||||
Invoke-WebRequest -Uri $Url -OutFile $OutFile -UseBasicParsing
|
||||
$size = (Get-Item $OutFile).Length
|
||||
Write-Status " Downloaded: $([math]::Round($size / 1MB, 1)) MB"
|
||||
} catch {
|
||||
if ($_.Exception.Response.StatusCode -eq 404) {
|
||||
throw "Download failed: not found at $Url"
|
||||
}
|
||||
throw "Download failed for ${Url}: $($_.Exception.Message)"
|
||||
}
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Uninstall
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
function Invoke-Uninstall {
|
||||
Write-Step "Uninstalling Ollama"
|
||||
|
||||
$regKey = Find-InnoSetupInstall
|
||||
if (-not $regKey) {
|
||||
Write-Host "Ollama is not installed."
|
||||
return
|
||||
}
|
||||
|
||||
$uninstallString = (Get-ItemProperty -Path $regKey).UninstallString
|
||||
if (-not $uninstallString) {
|
||||
Write-Warning "No uninstall string found in registry"
|
||||
return
|
||||
}
|
||||
|
||||
# Strip quotes if present
|
||||
$uninstallExe = $uninstallString -replace '"', ''
|
||||
Write-Status " Uninstaller: $uninstallExe"
|
||||
|
||||
if (-not (Test-Path $uninstallExe)) {
|
||||
Write-Warning "Uninstaller not found at: $uninstallExe"
|
||||
return
|
||||
}
|
||||
|
||||
Write-Host "Launching uninstaller..."
|
||||
# Run with GUI so user can choose whether to keep models
|
||||
Start-Process -FilePath $uninstallExe -Wait
|
||||
|
||||
# Verify removal
|
||||
if (Find-InnoSetupInstall) {
|
||||
Write-Warning "Uninstall may not have completed"
|
||||
} else {
|
||||
Write-Host "Ollama has been uninstalled."
|
||||
}
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Install
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
function Invoke-Install {
|
||||
# Determine installer URL
|
||||
if ($Version) {
|
||||
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe?version=$Version"
|
||||
} else {
|
||||
$installerUrl = "$DownloadBaseURL/OllamaSetup.exe"
|
||||
}
|
||||
|
||||
# Download installer
|
||||
Write-Step "Downloading Ollama"
|
||||
if (-not $DebugInstall) {
|
||||
Write-Host "Downloading Ollama..."
|
||||
}
|
||||
|
||||
$tempInstaller = Join-Path $env:TEMP "OllamaSetup.exe"
|
||||
Invoke-Download -Url $installerUrl -OutFile $tempInstaller
|
||||
|
||||
# Verify signature
|
||||
Write-Step "Verifying signature"
|
||||
if (-not (Test-Signature -FilePath $tempInstaller)) {
|
||||
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
|
||||
throw "Installer signature verification failed"
|
||||
}
|
||||
|
||||
# Build installer arguments
|
||||
$installerArgs = "/VERYSILENT /NORESTART /SUPPRESSMSGBOXES"
|
||||
if ($InstallDir) {
|
||||
$installerArgs += " /DIR=`"$InstallDir`""
|
||||
}
|
||||
Write-Status " Installer args: $installerArgs"
|
||||
|
||||
# Run installer
|
||||
Write-Step "Installing Ollama"
|
||||
if (-not $DebugInstall) {
|
||||
Write-Host "Installing..."
|
||||
}
|
||||
|
||||
# Create upgrade marker so the app starts hidden
|
||||
# The app checks for this file on startup and removes it after
|
||||
$markerDir = Join-Path $env:LOCALAPPDATA "Ollama"
|
||||
$markerFile = Join-Path $markerDir "upgraded"
|
||||
if (-not (Test-Path $markerDir)) {
|
||||
New-Item -ItemType Directory -Path $markerDir -Force | Out-Null
|
||||
}
|
||||
New-Item -ItemType File -Path $markerFile -Force | Out-Null
|
||||
Write-Status " Created upgrade marker: $markerFile"
|
||||
|
||||
# Start installer and wait for just the installer process (not children)
|
||||
# Using -Wait would wait for Ollama to exit too, which we don't want
|
||||
$proc = Start-Process -FilePath $tempInstaller `
|
||||
-ArgumentList $installerArgs `
|
||||
-PassThru
|
||||
$proc.WaitForExit()
|
||||
|
||||
if ($proc.ExitCode -ne 0) {
|
||||
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
|
||||
throw "Installation failed with exit code $($proc.ExitCode)"
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
|
||||
|
||||
# Update PATH in current session so 'ollama' works immediately
|
||||
Write-Step "Updating session PATH"
|
||||
Update-SessionPath
|
||||
|
||||
Write-Host "Install complete. You can now run 'ollama'."
|
||||
}
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Main
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
if ($Uninstall) {
|
||||
Invoke-Uninstall
|
||||
} else {
|
||||
Invoke-Install
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/sh
|
||||
# This script installs Ollama on Linux.
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
set -eu
|
||||
@@ -27,8 +27,7 @@ require() {
|
||||
echo $MISSING
|
||||
}
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
OS="$(uname -s)"
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
@@ -36,6 +35,63 @@ case "$ARCH" in
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
|
||||
|
||||
###########################################
|
||||
# macOS
|
||||
###########################################
|
||||
|
||||
if [ "$OS" = "Darwin" ]; then
|
||||
NEEDS=$(require curl unzip)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOWNLOAD_URL="https://ollama.com/download/Ollama-darwin.zip${VER_PARAM}"
|
||||
|
||||
if pgrep -x Ollama >/dev/null 2>&1; then
|
||||
status "Stopping running Ollama instance..."
|
||||
pkill -x Ollama 2>/dev/null || true
|
||||
sleep 2
|
||||
fi
|
||||
|
||||
if [ -d "/Applications/Ollama.app" ]; then
|
||||
status "Removing existing Ollama installation..."
|
||||
rm -rf "/Applications/Ollama.app"
|
||||
fi
|
||||
|
||||
status "Downloading Ollama for macOS..."
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
-o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
|
||||
|
||||
status "Installing Ollama to /Applications..."
|
||||
unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
|
||||
mv "$TEMP_DIR/Ollama.app" "/Applications/"
|
||||
|
||||
status "Adding 'ollama' command to PATH (may require password)..."
|
||||
mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
|
||||
ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
|
||||
sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
|
||||
|
||||
if [ -z "${OLLAMA_NO_START:-}" ]; then
|
||||
status "Starting Ollama..."
|
||||
open -a Ollama --args hidden
|
||||
fi
|
||||
|
||||
status "Install complete. You can now run 'ollama'."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
###########################################
|
||||
# Linux
|
||||
###########################################
|
||||
|
||||
[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
@@ -45,8 +101,6 @@ case "$KERN" in
|
||||
*) ;;
|
||||
esac
|
||||
|
||||
VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
# Running as root, no need for sudo
|
||||
|
||||
422
server/aliases.go
Normal file
422
server/aliases.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const (
|
||||
serverConfigFilename = "server.json"
|
||||
serverConfigVersion = 1
|
||||
)
|
||||
|
||||
var errAliasCycle = errors.New("alias cycle detected")
|
||||
|
||||
type aliasEntry struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
type serverConfig struct {
|
||||
Version int `json:"version"`
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type store struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
entries map[string]aliasEntry // normalized alias -> entry (exact matches)
|
||||
prefixEntries []aliasEntry // prefix matches, sorted longest-first
|
||||
}
|
||||
|
||||
func createStore(path string) (*store, error) {
|
||||
store := &store{
|
||||
path: path,
|
||||
entries: make(map[string]aliasEntry),
|
||||
}
|
||||
if err := store.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *store) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg serverConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Version != 0 && cfg.Version != serverConfigVersion {
|
||||
return fmt.Errorf("unsupported router config version %d", cfg.Version)
|
||||
}
|
||||
|
||||
for _, entry := range cfg.Aliases {
|
||||
targetName := model.ParseName(entry.Target)
|
||||
if !targetName.IsValid() {
|
||||
slog.Warn("invalid alias target in router config", "target", entry.Target)
|
||||
continue
|
||||
}
|
||||
canonicalTarget := displayAliasName(targetName)
|
||||
|
||||
if entry.PrefixMatching {
|
||||
// Prefix aliases don't need to be valid model names
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
slog.Warn("empty prefix alias in router config")
|
||||
continue
|
||||
}
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: alias,
|
||||
Target: canonicalTarget,
|
||||
PrefixMatching: true,
|
||||
})
|
||||
} else {
|
||||
aliasName := model.ParseName(entry.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
slog.Warn("invalid alias name in router config", "alias", entry.Alias)
|
||||
continue
|
||||
}
|
||||
canonicalAlias := displayAliasName(aliasName)
|
||||
s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
|
||||
Alias: canonicalAlias,
|
||||
Target: canonicalTarget,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort prefix entries by alias length descending (longest prefix wins)
|
||||
s.sortPrefixEntriesLocked()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *store) saveLocked() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine exact and prefix entries
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
|
||||
cfg := serverConfig{
|
||||
Version: serverConfigVersion,
|
||||
Aliases: entries,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(dir, "router-*.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), 0o644); err != nil {
|
||||
_ = os.Remove(f.Name())
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(f.Name(), s.path)
|
||||
}
|
||||
|
||||
func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
|
||||
// If a local model exists, do not allow alias shadowing (highest priority).
|
||||
exists, err := localModelExists(name)
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
if exists {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
key := normalizeAliasKey(name)
|
||||
|
||||
s.mu.RLock()
|
||||
entry, exactMatch := s.entries[key]
|
||||
var prefixMatch *aliasEntry
|
||||
if !exactMatch {
|
||||
// Try prefix matching - prefixEntries is sorted longest-first
|
||||
nameStr := strings.ToLower(displayAliasName(name))
|
||||
for i := range s.prefixEntries {
|
||||
prefix := strings.ToLower(s.prefixEntries[i].Alias)
|
||||
if strings.HasPrefix(nameStr, prefix) {
|
||||
prefixMatch = &s.prefixEntries[i]
|
||||
break // First match is longest due to sorting
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exactMatch && prefixMatch == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
var current string
|
||||
var visited map[string]struct{}
|
||||
|
||||
if exactMatch {
|
||||
visited = map[string]struct{}{key: {}}
|
||||
current = entry.Target
|
||||
} else {
|
||||
// For prefix match, use the target as-is
|
||||
visited = map[string]struct{}{}
|
||||
current = prefixMatch.Target
|
||||
}
|
||||
|
||||
targetKey := normalizeAliasKeyString(current)
|
||||
|
||||
for {
|
||||
targetName := model.ParseName(current)
|
||||
if !targetName.IsValid() {
|
||||
return name, false, fmt.Errorf("alias target %q is invalid", current)
|
||||
}
|
||||
|
||||
if _, seen := visited[targetKey]; seen {
|
||||
return name, false, errAliasCycle
|
||||
}
|
||||
visited[targetKey] = struct{}{}
|
||||
|
||||
s.mu.RLock()
|
||||
next, ok := s.entries[targetKey]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return targetName, true, nil
|
||||
}
|
||||
|
||||
current = next.Target
|
||||
targetKey = normalizeAliasKeyString(current)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
|
||||
targetKey := normalizeAliasKey(target)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if prefixMatching {
|
||||
// For prefix aliases, we skip cycle detection since prefix matching
|
||||
// works differently and the target is a specific model
|
||||
aliasStr := displayAliasName(alias)
|
||||
|
||||
// Remove any existing prefix entry with the same alias
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.prefixEntries = append(s.prefixEntries, aliasEntry{
|
||||
Alias: aliasStr,
|
||||
Target: displayAliasName(target),
|
||||
PrefixMatching: true,
|
||||
})
|
||||
s.sortPrefixEntriesLocked()
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
if aliasKey == targetKey {
|
||||
return fmt.Errorf("alias cannot point to itself")
|
||||
}
|
||||
|
||||
visited := map[string]struct{}{aliasKey: {}}
|
||||
currentKey := targetKey
|
||||
for {
|
||||
if _, seen := visited[currentKey]; seen {
|
||||
return errAliasCycle
|
||||
}
|
||||
visited[currentKey] = struct{}{}
|
||||
|
||||
next, ok := s.entries[currentKey]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
currentKey = normalizeAliasKeyString(next.Target)
|
||||
}
|
||||
|
||||
s.entries[aliasKey] = aliasEntry{
|
||||
Alias: displayAliasName(alias),
|
||||
Target: displayAliasName(target),
|
||||
}
|
||||
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *store) Delete(alias model.Name) (bool, error) {
|
||||
aliasKey := normalizeAliasKey(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try exact match first
|
||||
if _, ok := s.entries[aliasKey]; ok {
|
||||
delete(s.entries, aliasKey)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
// Try prefix entries
|
||||
aliasStr := displayAliasName(alias)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, aliasStr) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DeleteByString deletes an alias by its raw string value, useful for prefix
|
||||
// aliases that may not be valid model names.
|
||||
func (s *store) DeleteByString(alias string) (bool, error) {
|
||||
alias = strings.TrimSpace(alias)
|
||||
aliasLower := strings.ToLower(alias)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try prefix entries first (since this is mainly for prefix aliases)
|
||||
for i, e := range s.prefixEntries {
|
||||
if strings.EqualFold(e.Alias, alias) {
|
||||
s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// Also check exact entries by normalized key
|
||||
if _, ok := s.entries[aliasLower]; ok {
|
||||
delete(s.entries, aliasLower)
|
||||
return true, s.saveLocked()
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *store) List() []aliasEntry {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
|
||||
for _, entry := range s.entries {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
entries = append(entries, s.prefixEntries...)
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func normalizeAliasKey(name model.Name) string {
|
||||
return strings.ToLower(displayAliasName(name))
|
||||
}
|
||||
|
||||
func (s *store) sortPrefixEntriesLocked() {
|
||||
sort.Slice(s.prefixEntries, func(i, j int) bool {
|
||||
// Sort by length descending (longest prefix first)
|
||||
return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeAliasKeyString(value string) string {
|
||||
n := model.ParseName(value)
|
||||
if !n.IsValid() {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
return normalizeAliasKey(n)
|
||||
}
|
||||
|
||||
func displayAliasName(n model.Name) string {
|
||||
display := n.DisplayShortest()
|
||||
if strings.EqualFold(n.Tag, "latest") {
|
||||
if idx := strings.LastIndex(display, ":"); idx != -1 {
|
||||
return display[:idx]
|
||||
}
|
||||
}
|
||||
return display
|
||||
}
|
||||
|
||||
func localModelExists(name model.Name) (bool, error) {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
needle := name.String()
|
||||
for existing := range manifests {
|
||||
if strings.EqualFold(existing.String(), needle) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func serverConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".ollama", serverConfigFilename)
|
||||
}
|
||||
return filepath.Join(home, ".ollama", serverConfigFilename)
|
||||
}
|
||||
|
||||
func (s *Server) aliasStore() (*store, error) {
|
||||
s.aliasesOnce.Do(func() {
|
||||
s.aliases, s.aliasesErr = createStore(serverConfigPath())
|
||||
})
|
||||
|
||||
return s.aliases, s.aliasesErr
|
||||
}
|
||||
|
||||
func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
return name, false, err
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
return name, false, nil
|
||||
}
|
||||
|
||||
return store.ResolveName(name)
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type inferenceRequestLogger struct {
|
||||
dir string
|
||||
counter uint64
|
||||
}
|
||||
|
||||
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &inferenceRequestLogger{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *Server) initRequestLogging() error {
|
||||
if !envconfig.DebugLogRequests() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||
}
|
||||
|
||||
s.requestLogger = requestLogger
|
||||
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||
if s.requestLogger == nil {
|
||||
return handlers
|
||||
}
|
||||
|
||||
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
method := c.Request.Method
|
||||
host := c.Request.Host
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
l.log(route, method, scheme, host, contentType, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||
if l == nil || l.dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
if host == "" || scheme == "" {
|
||||
base := envconfig.Host()
|
||||
if host == "" {
|
||||
host = base.Host
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = base.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
routeForFilename := sanitizeRouteForFilename(route)
|
||||
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||
curlPath := filepath.Join(l.dir, curlFilename)
|
||||
|
||||
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||
}
|
||||
|
||||
func sanitizeRouteForFilename(route string) string {
|
||||
route = strings.TrimPrefix(route, "/")
|
||||
if route == "" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(route))
|
||||
for _, r := range route {
|
||||
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||
b.WriteRune(r)
|
||||
} else {
|
||||
b.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -51,7 +52,7 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
@@ -81,7 +82,9 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
requestLogger *inferenceRequestLogger
|
||||
aliasesOnce sync.Once
|
||||
aliases *store
|
||||
aliasesErr error
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -192,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
// We cannot currently consolidate this into GetModel because all we'll
|
||||
// induce infinite recursion given the current code structure.
|
||||
name, err := getExistingName(name)
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -1096,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
modelDetails.QuantizationLevel = info.Quantization
|
||||
@@ -1581,27 +1591,30 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
|
||||
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)...)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1651,9 +1664,6 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
if err := s.initRequestLogging(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
@@ -1954,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
name = resolvedName
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
m, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
|
||||
159
server/routes_aliases.go
Normal file
159
server/routes_aliases.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type aliasListResponse struct {
|
||||
Aliases []aliasEntry `json:"aliases"`
|
||||
}
|
||||
|
||||
type aliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
func (s *Server) ListAliasesHandler(c *gin.Context) {
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var aliases []aliasEntry
|
||||
if store != nil {
|
||||
aliases = store.List()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
|
||||
}
|
||||
|
||||
func (s *Server) CreateAliasHandler(c *gin.Context) {
|
||||
var req aliasEntry
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
req.Target = strings.TrimSpace(req.Target)
|
||||
if req.Alias == "" || req.Target == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Target must always be a valid model name
|
||||
targetName := model.ParseName(req.Target)
|
||||
if !targetName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
|
||||
return
|
||||
}
|
||||
|
||||
var aliasName model.Name
|
||||
if req.PrefixMatching {
|
||||
// For prefix aliases, we still parse the alias to normalize it,
|
||||
// but we allow any non-empty string since prefix patterns may not be valid model names
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
// Even if not valid as a model name, we accept it for prefix matching
|
||||
} else {
|
||||
aliasName = model.ParseName(req.Alias)
|
||||
if !aliasName.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := localModelExists(aliasName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, errAliasCycle) {
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp := aliasEntry{
|
||||
Alias: displayAliasName(aliasName),
|
||||
Target: displayAliasName(targetName),
|
||||
PrefixMatching: req.PrefixMatching,
|
||||
}
|
||||
if req.PrefixMatching && !aliasName.IsValid() {
|
||||
// For prefix aliases that aren't valid model names, use the raw alias
|
||||
resp.Alias = req.Alias
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) DeleteAliasHandler(c *gin.Context) {
|
||||
var req aliasDeleteRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Alias = strings.TrimSpace(req.Alias)
|
||||
if req.Alias == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
|
||||
return
|
||||
}
|
||||
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
aliasName := model.ParseName(req.Alias)
|
||||
var deleted bool
|
||||
if aliasName.IsValid() {
|
||||
deleted, err = store.Delete(aliasName)
|
||||
} else {
|
||||
// For invalid model names (like prefix aliases), try deleting by raw string
|
||||
deleted, err = store.DeleteByString(req.Alias)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !deleted {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
426
server/routes_aliases_test.go
Normal file
426
server/routes_aliases_test.go
Normal file
@@ -0,0 +1,426 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestAliasShadowingRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "shadowed-model",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasResolvesForChatRemote(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
var remoteModel string
|
||||
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req api.ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remoteModel = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
|
||||
p, err := url.Parse(rs.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
|
||||
s := Server{}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "target-model",
|
||||
RemoteHost: rs.URL,
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "alias-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "hi"}},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Model != "alias-model" {
|
||||
t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
|
||||
}
|
||||
|
||||
if remoteModel != "test" {
|
||||
t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasBasicMatching(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a prefix alias: "myprefix-" -> "targetmodel"
|
||||
targetName := model.ParseName("targetmodel")
|
||||
|
||||
// Set a prefix alias (using "myprefix-" as the pattern)
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = append(store.prefixEntries, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "targetmodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that "myprefix-foo" resolves to "targetmodel"
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != targetName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test that "otherprefix-foo" does not resolve
|
||||
otherName := model.ParseName("otherprefix-foo")
|
||||
_, wasResolved, err = store.ResolveName(otherName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatal("expected name not to be resolved")
|
||||
}
|
||||
|
||||
// Test that exact alias takes precedence
|
||||
exactAlias := model.ParseName("myprefix-exact")
|
||||
exactTarget := model.ParseName("exacttarget")
|
||||
if err := store.Set(exactAlias, exactTarget, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resolved, wasResolved, err = store.ResolveName(exactAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLongestMatchWins(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add two prefix aliases with overlapping patterns
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "abc-", Target: "short-target", PrefixMatching: true},
|
||||
{Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
|
||||
}
|
||||
store.sortPrefixEntriesLocked()
|
||||
store.mu.Unlock()
|
||||
|
||||
// "abc-def-ghi" should match the longer prefix "abc-def-"
|
||||
testName := model.ParseName("abc-def-ghi")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedLongTarget := model.ParseName("long-target")
|
||||
if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// "abc-xyz" should match the shorter prefix "abc-"
|
||||
testName2 := model.ParseName("abc-xyz")
|
||||
resolved, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
expectedShortTarget := model.ParseName("short-target")
|
||||
if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a chain: prefix "test-" -> "intermediate" -> "final"
|
||||
intermediate := model.ParseName("intermediate")
|
||||
final := model.ParseName("final")
|
||||
|
||||
// Add prefix alias
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "test-", Target: "intermediate", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Add exact alias for the intermediate step
|
||||
if err := store.Set(intermediate, final, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// "test-foo" should resolve through the chain to "final"
|
||||
testName := model.ParseName("test-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved")
|
||||
}
|
||||
if resolved.DisplayShortest() != final.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCRUD(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a prefix alias via API
|
||||
w := createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "llama2",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var createResp aliasEntry
|
||||
if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !createResp.PrefixMatching {
|
||||
t.Fatal("expected prefix_matching to be true in response")
|
||||
}
|
||||
|
||||
// List aliases and verify the prefix alias is included
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var listResp aliasListResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching && a.Target == "llama2" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected to find prefix alias in list")
|
||||
}
|
||||
|
||||
// Delete the prefix alias
|
||||
w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify it's deleted
|
||||
w = createRequest(t, s.ListAliasesHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, a := range listResp.Aliases {
|
||||
if a.PrefixMatching {
|
||||
t.Fatal("expected prefix alias to be deleted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasCaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
store, err := createStore(filepath.Join(tmpDir, "server.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a prefix alias with mixed case
|
||||
store.mu.Lock()
|
||||
store.prefixEntries = []aliasEntry{
|
||||
{Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
// Test that matching is case-insensitive
|
||||
testName := model.ParseName("myprefix-foo")
|
||||
resolved, wasResolved, err := store.ResolveName(testName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (case-insensitive)")
|
||||
}
|
||||
expectedTarget := model.ParseName("targetmodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Test uppercase request
|
||||
testName2 := model.ParseName("MYPREFIX-BAR")
|
||||
_, wasResolved, err = store.ResolveName(testName2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected name to be resolved (uppercase)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("HOME", t.TempDir())
|
||||
|
||||
s := Server{}
|
||||
|
||||
// Create a local model that would match a prefix alias
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "myprefix-localmodel",
|
||||
RemoteHost: "example.com",
|
||||
From: "test",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Create a prefix alias that would match the local model name
|
||||
w = createRequest(t, s.CreateAliasHandler, aliasEntry{
|
||||
Alias: "myprefix-",
|
||||
Target: "someothermodel",
|
||||
PrefixMatching: true,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
|
||||
store, err := s.aliasStore()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
localModelName := model.ParseName("myprefix-localmodel")
|
||||
resolved, wasResolved, err := store.ResolveName(localModelName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wasResolved {
|
||||
t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
|
||||
}
|
||||
if resolved.DisplayShortest() != localModelName.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
|
||||
// Also verify that a non-local model matching the prefix DOES resolve to the alias target
|
||||
nonLocalName := model.ParseName("myprefix-nonexistent")
|
||||
resolved, wasResolved, err = store.ResolveName(nonLocalName)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasResolved {
|
||||
t.Fatal("expected non-local model to resolve via prefix alias")
|
||||
}
|
||||
expectedTarget := model.ParseName("someothermodel")
|
||||
if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
|
||||
t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
|
||||
}
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logDir := t.TempDir()
|
||||
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||
|
||||
const route = "/v1/chat/completions"
|
||||
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||
|
||||
var bodySeenByHandler string
|
||||
|
||||
r := gin.New()
|
||||
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body in handler: %v", err)
|
||||
}
|
||||
|
||||
bodySeenByHandler = string(body)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||
req.Host = "127.0.0.1:11434"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if bodySeenByHandler != requestBody {
|
||||
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||
}
|
||||
|
||||
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob body logs: %v", err)
|
||||
}
|
||||
if len(bodyFiles) != 1 {
|
||||
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||
}
|
||||
|
||||
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob curl logs: %v", err)
|
||||
}
|
||||
if len(curlFiles) != 1 {
|
||||
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||
}
|
||||
|
||||
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body log: %v", err)
|
||||
}
|
||||
if string(bodyData) != requestBody {
|
||||
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||
}
|
||||
|
||||
curlData, err := os.ReadFile(curlFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read curl log: %v", err)
|
||||
}
|
||||
|
||||
curlString := string(curlData)
|
||||
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||
}
|
||||
|
||||
bodyFileName := filepath.Base(bodyFiles[0])
|
||||
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error creating request logger: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(requestLogger.dir)
|
||||
})
|
||||
|
||||
if requestLogger == nil || requestLogger.dir == "" {
|
||||
t.Fatalf("expected request logger directory to be set")
|
||||
}
|
||||
|
||||
info, err := os.Stat(requestLogger.dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected directory to exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
route string
|
||||
want string
|
||||
}{
|
||||
{route: "/api/generate", want: "api_generate"},
|
||||
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||
{route: "/v1/messages", want: "v1_messages"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -567,16 +567,16 @@ iGPUScan:
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode mlxrunner.ModelMode
|
||||
var mode imagegen.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = mlxrunner.ModeImageGen
|
||||
mode = imagegen.ModeImageGen
|
||||
} else {
|
||||
mode = mlxrunner.ModeLLM
|
||||
mode = imagegen.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := mlxrunner.NewServer(modelName, mode)
|
||||
server, err := imagegen.NewServer(modelName, mode)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
for _, p := range pretokenizer {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
@@ -17,7 +17,7 @@ type SentencePiece struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
var _ Tokenizer = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type TextProcessor interface {
|
||||
type Tokenizer interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
// Decode implements Tokenizer.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
// Encode implements Tokenizer.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
// Is implements Tokenizer.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
// Vocabulary implements Tokenizer.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
var _ Tokenizer = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
@@ -35,7 +36,7 @@ type ModelfileConfig struct {
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
@@ -94,6 +95,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities, parserName, rendererName),
|
||||
progressFn,
|
||||
newPackedTensorLayerCreator(),
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
@@ -141,60 +143,33 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
// createQuantizedLayers quantizes a tensor and returns a single combined layer.
|
||||
// The combined blob contains data, scale, and optional bias tensors with metadata.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
// Quantize the tensor into a single combined blob
|
||||
blobData, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
|
||||
// Create single layer for the combined blob
|
||||
layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
@@ -214,6 +189,58 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for
|
||||
// creating packed multi-tensor blob layers (used for expert groups).
|
||||
func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
|
||||
return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) {
|
||||
// Check if any tensor in the group needs quantization
|
||||
hasQuantize := false
|
||||
for _, t := range tensors {
|
||||
if t.Quantize != "" {
|
||||
hasQuantize = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var blobReader io.Reader
|
||||
if hasQuantize {
|
||||
if !QuantizeSupported() {
|
||||
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
blobData, err := quantizePackedGroup(tensors)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
|
||||
}
|
||||
blobReader = bytes.NewReader(blobData)
|
||||
} else {
|
||||
// Build unquantized packed blob using streaming reader
|
||||
// Extract raw tensor data from safetensors-wrapped readers
|
||||
var tds []*safetensors.TensorData
|
||||
for _, t := range tensors {
|
||||
rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err)
|
||||
}
|
||||
td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData)
|
||||
tds = append(tds, td)
|
||||
}
|
||||
blobReader = safetensors.BuildPackedSafetensorsReader(tds)
|
||||
}
|
||||
|
||||
layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: groupName,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
|
||||
@@ -3,128 +3,195 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/x/create"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types:
|
||||
// - "q4": affine 4-bit, group_size=32 (with qbiases)
|
||||
// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
|
||||
// - "q8": affine 8-bit, group_size=64 (with qbiases)
|
||||
// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
// quantizeParams maps quantization type names to MLX quantize parameters.
|
||||
var quantizeParams = map[string]struct {
|
||||
groupSize int
|
||||
bits int
|
||||
mode string
|
||||
}{
|
||||
"int4": {32, 4, "affine"},
|
||||
"nvfp4": {16, 4, "nvfp4"},
|
||||
"int8": {64, 8, "affine"},
|
||||
"mxfp8": {32, 8, "mxfp8"},
|
||||
}
|
||||
|
||||
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
|
||||
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
|
||||
// to the provided maps. If quantize is empty, the tensor is kept as-is.
|
||||
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
|
||||
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
tmpPath = tmpFile.Name()
|
||||
|
||||
if _, err := io.Copy(tmpFile, r); err != nil {
|
||||
tmpFile.Close()
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Load the tensor using MLX's native loader
|
||||
st, err := mlx.LoadSafetensorsNative(tmpPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
|
||||
}
|
||||
defer st.Free()
|
||||
|
||||
// Get the tensor (it's stored as "data" in our minimal safetensors format)
|
||||
arr := st.Get("data")
|
||||
// Find the tensor key (may differ from name for single-tensor blobs)
|
||||
inputKey, err := findSafetensorsKey(tmpPath)
|
||||
if err != nil {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
|
||||
}
|
||||
|
||||
arr := st.Get(inputKey)
|
||||
if arr == nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
|
||||
}
|
||||
|
||||
// Convert to BFloat16 if needed (quantize expects float type)
|
||||
if quantize == "" {
|
||||
arr = mlx.Contiguous(arr)
|
||||
arrays[name] = arr
|
||||
return tmpPath, []*mlx.Array{arr}, st, nil
|
||||
}
|
||||
|
||||
// Convert to float type if needed (quantize expects float)
|
||||
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
||||
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "q4":
|
||||
// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
|
||||
case "nvfp4":
|
||||
// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
|
||||
case "q8":
|
||||
// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
|
||||
case "mxfp8":
|
||||
// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
params, ok := quantizeParams[quantize]
|
||||
if !ok {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
|
||||
// Eval and make contiguous for data access
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
|
||||
|
||||
qweight = mlx.Contiguous(qweight)
|
||||
scales = mlx.Contiguous(scales)
|
||||
arrays[name] = qweight
|
||||
arrays[name+".scale"] = scales
|
||||
toEval = append(toEval, qweight, scales)
|
||||
|
||||
if qbiases != nil {
|
||||
qbiases = mlx.Contiguous(qbiases)
|
||||
mlx.Eval(qweight, scales, qbiases)
|
||||
} else {
|
||||
mlx.Eval(qweight, scales)
|
||||
arrays[name+".bias"] = qbiases
|
||||
toEval = append(toEval, qbiases)
|
||||
}
|
||||
|
||||
// Get shapes
|
||||
qweightShape = qweight.Shape()
|
||||
scalesShape = scales.Shape()
|
||||
return tmpPath, toEval, st, nil
|
||||
}
|
||||
|
||||
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
|
||||
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
|
||||
defer os.Remove(qweightPath)
|
||||
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
|
||||
// Tensor keys use the original tensor name: name, name.scale, name.bias.
|
||||
// The blob includes __metadata__ with quant_type and group_size.
|
||||
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
|
||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||
arrays := make(map[string]*mlx.Array)
|
||||
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
|
||||
if tmpPath != "" {
|
||||
defer os.Remove(tmpPath)
|
||||
}
|
||||
if st != nil {
|
||||
defer st.Free()
|
||||
}
|
||||
qweightData, err = os.ReadFile(qweightPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save scales using MLX's native safetensors
|
||||
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
|
||||
defer os.Remove(scalesPath)
|
||||
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
|
||||
}
|
||||
scalesData, err = os.ReadFile(scalesPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
|
||||
mlx.Eval(toEval...)
|
||||
|
||||
// Build metadata for single-tensor blobs
|
||||
params := quantizeParams[quantize]
|
||||
metadata := map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(params.groupSize),
|
||||
}
|
||||
|
||||
// Affine mode returns qbiases for zero-point offset
|
||||
if qbiases != nil {
|
||||
qbiasShape = qbiases.Shape()
|
||||
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
|
||||
defer os.Remove(qbiasPath)
|
||||
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
|
||||
tmpDir := ensureTempDir()
|
||||
outPath := filepath.Join(tmpDir, "combined.safetensors")
|
||||
defer os.Remove(outPath)
|
||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to save combined blob: %w", err)
|
||||
}
|
||||
return os.ReadFile(outPath)
|
||||
}
|
||||
|
||||
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
|
||||
// combined safetensors blob. Used for packing expert groups.
|
||||
// Each tensor may have a different quantization type (mixed-precision).
|
||||
// Returns the blob bytes. No __metadata__ is added because different tensors
|
||||
// may use different quantization types.
|
||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var allToEval []*mlx.Array
|
||||
var tmpPaths []string
|
||||
var handles []*mlx.SafetensorsFile
|
||||
|
||||
for _, input := range inputs {
|
||||
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
|
||||
if tmpPath != "" {
|
||||
tmpPaths = append(tmpPaths, tmpPath)
|
||||
}
|
||||
if st != nil {
|
||||
handles = append(handles, st)
|
||||
}
|
||||
qbiasData, err = os.ReadFile(qbiasPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
|
||||
// Cleanup on error
|
||||
for _, h := range handles {
|
||||
h.Free()
|
||||
}
|
||||
for _, p := range tmpPaths {
|
||||
os.Remove(p)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
allToEval = append(allToEval, toEval...)
|
||||
}
|
||||
|
||||
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
|
||||
mlx.Eval(allToEval...)
|
||||
|
||||
// Free native handles after eval
|
||||
for _, h := range handles {
|
||||
h.Free()
|
||||
}
|
||||
|
||||
// Save combined blob (no global metadata for mixed-precision packed blobs)
|
||||
tmpDir := ensureTempDir()
|
||||
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
|
||||
defer os.Remove(outPath)
|
||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
|
||||
return nil, fmt.Errorf("failed to save packed blob: %w", err)
|
||||
}
|
||||
|
||||
blobData, err := os.ReadFile(outPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read packed blob: %w", err)
|
||||
}
|
||||
|
||||
for _, p := range tmpPaths {
|
||||
os.Remove(p)
|
||||
}
|
||||
|
||||
return blobData, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||
@@ -138,3 +205,33 @@ func ensureTempDir() string {
|
||||
os.MkdirAll(tmpDir, 0755)
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
|
||||
func findSafetensorsKey(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return "", err
|
||||
}
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, headerBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for k := range header {
|
||||
if k != "__metadata__" {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no tensor found in safetensors header")
|
||||
}
|
||||
|
||||
@@ -5,11 +5,18 @@ package client
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// quantizePackedGroup is not available without MLX
|
||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// QuantizeSupported returns false when MLX is not available
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user